Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Separate Q and KV dtypes for decode #286

Merged
merged 10 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 59 additions & 54 deletions include/flashinfer/attention/decode.cuh

Large diffs are not rendered by default.

30 changes: 16 additions & 14 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ namespace flashinfer {

template <bool partition_kv, PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem,
uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
PageStorage page_storage, QKVLayout kv_layout, typename DTypeQ, typename DTypeKV,
typename DTypeOut, typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
DTypeIn* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse,
bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale,
Expand Down Expand Up @@ -86,7 +86,7 @@ std::pair<uint32_t, uint32_t> PartitionPagedKVCacheBinarySearchMinNumPagePerBatc
* \brief Estimate the temporary buffer size and the maximum grid size for the
* partition-kv BatchDecodeWithPagedKVCache kernel
* \tparam page_storage Whether to store indices or pointers of each active page
* \tparam DTypeIn A template type that indicates the input data type
* \tparam DTypeKV A template type that indicates the key-value data type
* \tparam DTypeOut A template type that indicates the output data type
* \tparam IdType A template type that indicates the index data type
* \param tmp_size The estimated temporary buffer size, return 0 if not use partition-kv kernel
Expand All @@ -100,27 +100,29 @@ std::pair<uint32_t, uint32_t> PartitionPagedKVCacheBinarySearchMinNumPagePerBatc
* \return status Indicates whether CUDA calls are successful
*/
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads,
const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) {
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL);
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
constexpr uint32_t num_stages_smem = 2U;
constexpr uint32_t bdx = HEAD_DIM / vec_size;
static_assert(bdx <= 32);
constexpr uint32_t bdy = GROUP_SIZE;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U;
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U;
const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
const uint32_t smem_size =
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float));
2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));

// Note that the dtype of Q should not affect the value returned by
// cudaOccupancyMaxActiveBlocksPerMultiprocessor, which is why we just use DTypeKV here: it simplifies the API.
auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel<
/*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx,
bdy, bdz, page_storage, kv_layout, DTypeIn, DTypeOut, IdType>;
bdy, bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>;
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
Expand Down Expand Up @@ -294,7 +296,7 @@ class BatchDecodeHandler {
bool* GetBlockValidMask() const { return block_valid_mask_; }

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
IdType* last_page_len, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t page_size) {
Expand All @@ -303,8 +305,8 @@ class BatchDecodeHandler {
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
auto work_estimation_func =
BatchDecodeWithPagedKVCacheWorkEstimationDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
kv_layout, POS_ENCODING_MODE, DTypeIn,
DTypeOut, IdType>;
kv_layout, POS_ENCODING_MODE,
DTypeQ, DTypeKV, DTypeOut, IdType>;
FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch,
new_batch_size, batch_size, indptr, num_qo_heads,
page_size,
Expand Down
22 changes: 11 additions & 11 deletions include/flashinfer/decode_attention_decl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,35 +29,35 @@
namespace flashinfer {

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
PosEncodingMode pos_encoding_mode, typename DTypeIn, typename DTypeOut>
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
PosEncodingMode pos_encoding_mode, typename DTypeQ, typename DTypeKV, typename DTypeOut>
cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, uint32_t num_kv_heads,
uint32_t seq_len, float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream);

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PageStorage page_storage, QKVLayout kv_layout,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheDispatched(
DTypeIn* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
DTypeQ* q, IdType* q_offset, paged_kv_t<page_storage, kv_layout, DTypeKV, IdType> paged_kv,
kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale,
float rope_scale, float rope_theta, cudaStream_t stream);

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut>
cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
DTypeOut* tmp, float* lse, uint32_t batch_size,
uint32_t padded_kv_len, uint32_t num_qo_heads,
float sm_scale, float rope_scale,
float rope_theta, cudaStream_t stream);

template <PageStorage page_storage, QKVLayout KV_LAYOUT, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
PosEncodingMode POS_ENCODING_MODE, typename DTypeIn, typename DTypeOut, typename IdType>
PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset,
paged_kv_t<page_storage, KV_LAYOUT, DTypeKV, IdType> paged_kv, DTypeOut* o, float* lse,
float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
paged_kv_t<page_storage, KV_LAYOUT, DTypeIn, IdType> new_paged_kv = paged_kv;
paged_kv_t<page_storage, KV_LAYOUT, DTypeKV, IdType> new_paged_kv = paged_kv;
kv_partition_info_t<IdType> kv_partition_info;
DTypeOut* tmp_v = handler->GetTempV<DTypeOut>();
float* tmp_s = handler->GetTempS<float>();
Expand All @@ -82,7 +82,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched(
}

return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage, KV_LAYOUT,
POS_ENCODING_MODE, DTypeIn, DTypeOut, IdType>(
POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>(
q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse,
handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), sm_scale, rope_scale, rope_theta,
stream);
Expand Down
Loading