From b47122b809b7c4cede02bb81728dd0bde43fa458 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 5 Jun 2024 15:20:55 -0700 Subject: [PATCH 1/9] Separate Q and KV dtypes for decode --- include/flashinfer/attention/decode.cuh | 113 ++++++----- include/flashinfer/decode_attention_decl.cuh | 22 +-- python/csrc/batch_decode.cu | 196 ++++++++++--------- python/csrc/pytorch_extension_utils.h | 39 ++-- python/csrc/single_decode.cu | 2 +- python/generate_batch_padded_decode_inst.py | 12 +- python/generate_batch_paged_decode_inst.py | 13 +- python/generate_single_decode_inst.py | 11 +- python/setup.py | 71 ++----- src/flashinfer_ops.cuh | 39 ++-- 10 files changed, 257 insertions(+), 261 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index e63d068d..fde71c2f 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -184,7 +184,8 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension * \tparam bdy A template integer indicates the block size in y dimension - * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeQ A template type indicates the query data type + * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \param q [num_qo_heads, head_dim] The query matrix * \param k [seq_len, num_kv_heads, head_dim] The key matrix in kv-cache @@ -203,9 +204,9 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f */ template -__global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, - DTypeIn* __restrict__ v, DTypeOut* __restrict__ o, + uint32_t bdy, uint32_t bdz, typename DTypeQ, typename DTypeKV, typename DTypeOut> +__global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, + DTypeKV* __restrict__ v, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp, tensor_info_t info, float sm_scale, float rope_rcp_scale, @@ -224,11 +225,11 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* uint32_t seq_len = info.kv_len; extern __shared__ uint8_t smem[]; - DTypeIn* k_smem = (DTypeIn*)smem; - DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * - sizeof(DTypeIn)); + DTypeKV* k_smem = (DTypeKV*)smem; + DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * + sizeof(DTypeKV)); float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * - sizeof(DTypeIn)); + sizeof(DTypeKV)); uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; @@ -260,7 +261,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* // preload k tiles and v tiles uint32_t producer_kv_idx_base = chunk_start; - constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; #pragma unroll for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { @@ -356,10 +357,10 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* } template + uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename DTypeQ, + typename DTypeKV, typename DTypeOut> __global__ void BatchDecodeWithPaddedKVCacheKernel( - DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, + DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v, DTypeOut* __restrict__ o, float* __restrict__ lse, tensor_info_t info, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) { @@ -376,9 +377,9 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( uint32_t seq_len = info.kv_len; extern __shared__ uint8_t smem[]; - DTypeIn* k_smem = (DTypeIn*)smem; - DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn)); - float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn)); + DTypeKV* k_smem = (DTypeKV*)smem; + DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeKV)); + float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeKV)); uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; @@ -407,7 +408,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( // preload k tiles and v tiles uint32_t producer_kv_idx_base = 0; - constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; #pragma unroll for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { cp_async::pred_load( @@ -495,7 +496,8 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( * \tparam bdy A template integer indicates the block size in y dimension * \tparam bdz A template integer indicates the block size in z dimension * \tparam page_storage Whether to store indices or pointers of each active page - * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeQ A template type indicates the query data type + * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \tparam IdType A template type indicates the index data type * \param q [batch_size, num_qo_heads, head_dim] The query matrix @@ -512,11 +514,11 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( */ template + PageStorage page_storage, QKVLayout kv_layout, typename DTypeQ, typename DTypeKV, + typename DTypeOut, typename IdType> __global__ void BatchDecodeWithPagedKVCacheKernel( - DTypeIn* __restrict__ q, IdType* __restrict__ q_offset, - paged_kv_t paged_kv, + DTypeQ* __restrict__ q, IdType* __restrict__ q_offset, + paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) { @@ -544,13 +546,13 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( partition_kv ? kv_partition_info.batch_idx_map[batch_idx] : batch_idx; extern __shared__ uint8_t smem[]; - DTypeIn* k_smem = (DTypeIn*)smem; - DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * - sizeof(DTypeIn)); - DTypeIn** k_ptrs_smem = (DTypeIn**)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * - head_dim * sizeof(DTypeIn)); + DTypeKV* k_smem = (DTypeKV*)smem; + DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * + sizeof(DTypeKV)); + DTypeKV** k_ptrs_smem = (DTypeKV**)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * + head_dim * sizeof(DTypeKV)); float* smem_md = (float*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * - sizeof(DTypeIn)); + sizeof(DTypeKV)); const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; @@ -578,7 +580,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( // preload k/v tiles uint32_t stage_idx = 0; - constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; static_assert(num_stages_smem <= bdx); @@ -590,7 +592,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( } block.sync(); - DTypeIn* k_ptrs[tile_size_per_bdx]; + DTypeKV* k_ptrs[tile_size_per_bdx]; #pragma unroll for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { #pragma unroll @@ -608,7 +610,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( cp_async::commit_group(); #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - DTypeIn* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta(); + DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta(); cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, @@ -677,7 +679,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( // load v tiles #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - DTypeIn* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta(); + DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta(); cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, @@ -731,7 +733,8 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo /*! * \brief FlashAttention decoding with kv-cache for a single request - * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeQ A template type indicates the query data type + * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \param q The query matrix, shape: [num_qo_heads, head_dim] * \param k The key matrix in kv-cache, shape: [seq_len, num_kv_heads, head_dim] @@ -752,33 +755,34 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo * \return status Indicates whether CUDA calls are successful */ template -cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, + typename DTypeOut> +cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_kv_heads, uint32_t seq_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; - constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32U); constexpr uint32_t bdy = GROUP_SIZE; constexpr uint32_t num_threads = - std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)), bdx * bdy); + std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeKV)), bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); tensor_info_t info(1, seq_len, num_kv_heads); - constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 8U) : 1U; + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 8U) : 1U; const uint32_t smem_size = - 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeIn) + + 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + 2U * bdy * bdz * sizeof(float); if (seq_len <= 256 || tmp == nullptr) { // no need to use partition-kv kernel auto kernel = SingleDecodeWithKVCacheKernel; + DTypeQ, DTypeKV, DTypeOut>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -799,7 +803,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v // use partition-kv kernel auto kernel = SingleDecodeWithKVCacheKernel; + bdy, bdz, DTypeQ, DTypeKV, DTypeOut>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -839,9 +843,9 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v } template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheDispatched( - DTypeIn* q, IdType* q_offset, paged_kv_t paged_kv, + DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { const float rope_rcp_scale = 1.f / rope_scale; @@ -850,17 +854,17 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( const uint32_t batch_size = paged_kv.batch_size; const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; - constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); constexpr uint32_t bdy = GROUP_SIZE; constexpr uint32_t num_threads = std::max(128U, bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); - constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U; + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; const uint32_t smem_size = - 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) + - std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float)); + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); if (tmp == nullptr) { // do not use partition-kv kernel @@ -869,7 +873,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( auto kernel = BatchDecodeWithPagedKVCacheKernel; + bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, @@ -888,7 +892,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel; + kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, @@ -916,7 +920,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( /*! * \brief FlashAttention decoding cuda kernel with paged kv-cache for batched requests * \tparam page_storage Whether to store indices or pointers of each active page - * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeQ A template type indicates the query data type + * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \tparam IdType A template type indicates the index data type used in paged kv-cache * \param q [batch_size, num_qo_heads, head_dim] The query matrix @@ -932,8 +937,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( * \return status Indicates whether CUDA calls are successful */ template -cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut> +cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, float sm_scale, float rope_scale, @@ -942,7 +947,7 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DType const float rope_rcp_theta = 1.f / rope_theta; const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; - constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); @@ -951,12 +956,12 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DType constexpr uint32_t bdz = num_threads / (bdx * bdy); const uint32_t smem_size = - 2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) + 2 * bdy * bdz * sizeof(float); + 2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + 2 * bdy * bdz * sizeof(float); dim3 nblks(batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); auto kernel = BatchDecodeWithPaddedKVCacheKernel; + vec_size, bdx, bdy, bdz, DTypeQ, DTypeKV, DTypeOut>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); tensor_info_t info(1, padded_kv_len, num_kv_heads); diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index 1cd96bd9..ae9e29f0 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -29,34 +29,34 @@ namespace flashinfer { template -cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, + PosEncodingMode pos_encoding_mode, typename DTypeQ, typename DTypeKV, typename DTypeOut> +cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_kv_heads, uint32_t seq_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheDispatched( - DTypeIn* q, IdType* q_offset, paged_kv_t paged_kv, + DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); template -cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut> +cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( - BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, + BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, + paged_kv_t paged_kv, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { - paged_kv_t new_paged_kv = paged_kv; + paged_kv_t new_paged_kv = paged_kv; kv_partition_info_t kv_partition_info; DTypeOut* tmp = handler->GetTempFloatBuffer(); @@ -80,7 +80,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( } return BatchDecodeWithPagedKVCacheDispatched( + POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, rope_theta, stream); } diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index b58b5efb..20041905 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -33,8 +33,7 @@ std::vector batch_decode_with_padded_kv_cache( CHECK_SHAPE(k_padded, v_padded); CHECK_EQ(q.size(0), k_padded.size(0)); CHECK_EQ(q.size(2), k_padded.size(3)); - CHECK_EQ(q.scalar_type(), k_padded.scalar_type()); - CHECK_EQ(q.scalar_type(), v_padded.scalar_type()); + CHECK_EQ(v_padded.scalar_type(), k_padded.scalar_type()); unsigned int batch_size = q.size(0); unsigned int num_qo_heads = q.size(1); unsigned int head_dim = q.size(2); @@ -58,53 +57,57 @@ std::vector batch_decode_with_padded_kv_cache( } if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - nv_half* tmp = nullptr; - cudaError_t status = - BatchDecodeWithPaddedKVCacheDispatched( - static_cast(q.data_ptr()), - static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), - static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPaddedKVCache failed with error code ", status); - return true; + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE(k_padded.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + nv_half* tmp = nullptr; + cudaError_t status = + BatchDecodeWithPaddedKVCacheDispatched( + static_cast(q.data_ptr()), + static_cast(k_padded.data_ptr()), + static_cast(v_padded.data_ptr()), + static_cast(o.data_ptr()), + /*tmp=*/tmp, + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPaddedKVCache failed with error code ", status); + return true; + }); }); - }); + }); }); }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - c_type* tmp = nullptr; - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - cudaError_t status = BatchDecodeWithPaddedKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type>( - static_cast(q.data_ptr()), static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPaddedKVCache failed with error code ", status); - return true; + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE(k_padded.scalar_type(), kv_type, [&] { + q_type* tmp = nullptr; + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + cudaError_t status = BatchDecodeWithPaddedKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, q_type>( + static_cast(q.data_ptr()), static_cast(k_padded.data_ptr()), + static_cast(v_padded.data_ptr()), static_cast(o.data_ptr()), + /*tmp=*/tmp, + /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, + batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPaddedKVCache failed with error code ", status); + return true; + }); }); - }); + }); }); }); }); @@ -137,7 +140,7 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( handler_->SetCUDAStream(torch_current_stream); if (is_float8_tensor(empty_data)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { @@ -208,7 +211,6 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( CHECK_DIM(1, paged_kv_last_page_len); // (B,) CHECK_DIM(1, paged_kv_indptr); // (B+1,) CHECK_DIM(1, paged_kv_indices); // (nnz,) - CHECK_EQ(q.scalar_type(), paged_kv_data.scalar_type()); // (num_max_pages, 2, H_kv, page_size, head_dim) for HND // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD CHECK_DIM(5, paged_kv_data); @@ -242,61 +244,65 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( } if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, - c_type, nv_half, int32_t>( - handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, - paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE(paged_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, + q_type, kv_type, nv_half, int32_t>( + handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, + paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); }); }); }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, - c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, - paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE(paged_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, + q_type, kv_type, q_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, + paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); }); }); }); diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 8d5a6952..e3bcaaee 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -31,7 +31,7 @@ using namespace flashinfer; -#ifdef FLASHINFER_ENABLE_BF16 +#if defined (FLASHINFER_ENABLE_BF16) && defined (FLASHINFER_ENABLE_FP8) #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ @@ -43,6 +43,14 @@ using namespace flashinfer; using c_type = nv_bfloat16; \ return __VA_ARGS__(); \ } \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ default: \ std::ostringstream oss; \ oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ @@ -50,7 +58,7 @@ using namespace flashinfer; return false; \ } \ }() -#else +#elif defined (FLASHINFER_ENABLE_BF16) #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ @@ -58,6 +66,10 @@ using namespace flashinfer; using c_type = nv_half; \ return __VA_ARGS__(); \ } \ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ default: \ std::ostringstream oss; \ oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ @@ -65,10 +77,8 @@ using namespace flashinfer; return false; \ } \ }() -#endif - -#ifdef FLASHINFER_ENABLE_FP8 -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ +#elif defined (FLASHINFER_ENABLE_FP8) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Float8_e4m3fn: { \ @@ -87,12 +97,19 @@ using namespace flashinfer; } \ }() #else -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ }() #endif diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index 9172fd8b..6a767f91 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -51,7 +51,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { diff --git a/python/generate_batch_padded_decode_inst.py b/python/generate_batch_padded_decode_inst.py index 1ef596d4..fa5fb973 100644 --- a/python/generate_batch_padded_decode_inst.py +++ b/python/generate_batch_padded_decode_inst.py @@ -29,15 +29,16 @@ def get_cu_file_str( head_dim, kv_layout, pos_encoding_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, ): content = """#include namespace flashinfer {{ -template cudaError_t BatchDecodeWithPaddedKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {dtype_in}, {dtype_out}>( - {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, +template cudaError_t BatchDecodeWithPaddedKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( + {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, {dtype_out}* o, {dtype_out}* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, float sm_scale, float rope_scale, @@ -49,7 +50,8 @@ def get_cu_file_str( group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], - dtype_in=dtype_literal[dtype_in], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], ) return content @@ -58,7 +60,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"batch_padded_decode_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" + r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/generate_batch_paged_decode_inst.py b/python/generate_batch_paged_decode_inst.py index e9c06605..56fe44aa 100644 --- a/python/generate_batch_paged_decode_inst.py +++ b/python/generate_batch_paged_decode_inst.py @@ -26,7 +26,7 @@ def get_cu_file_str( - group_size, head_dim, kv_layout, pos_encoding_mode, dtype_in, dtype_out, idtype + group_size, head_dim, kv_layout, pos_encoding_mode, dtype_q, dtype_kv, dtype_out, idtype ): content = """#include @@ -34,9 +34,9 @@ def get_cu_file_str( constexpr PageStorage page_storage = PageStorage::kIndices; -template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{group_size}, {head_dim}, page_storage, {kv_layout}, {pos_encoding_mode}, {dtype_in}, {dtype_out}, {idtype}>( - {dtype_in}* q, {idtype}* q_offset, - paged_kv_t paged_kv, +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{group_size}, {head_dim}, page_storage, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv} {dtype_out}, {idtype}>( + {dtype_q}* q, {idtype}* q_offset, + paged_kv_t paged_kv, kv_partition_info_t<{idtype}> kv_partition_info, {dtype_out}* o, {dtype_out}* tmp, float* lse, float sm_scale, float rope_scale, @@ -48,7 +48,8 @@ def get_cu_file_str( group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], - dtype_in=dtype_literal[dtype_in], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], idtype=idtype_literal[idtype], ) @@ -58,7 +59,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"batch_paged_decode_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/generate_single_decode_inst.py b/python/generate_single_decode_inst.py index 67d417c0..8fc36218 100644 --- a/python/generate_single_decode_inst.py +++ b/python/generate_single_decode_inst.py @@ -21,14 +21,14 @@ def get_cu_file_str( - group_size, head_dim, kv_layout, pos_encoding_mode, dtype_in, dtype_out + group_size, head_dim, kv_layout, pos_encoding_mode, dtype_q, dtype_kv, dtype_out ): content = """#include namespace flashinfer {{ -template cudaError_t SingleDecodeWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {dtype_in}, {dtype_out}>( - {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, {dtype_out}* o, +template cudaError_t SingleDecodeWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( + {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, {dtype_out}* o, {dtype_out}* tmp, uint32_t num_kv_heads, uint32_t seq_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); @@ -39,7 +39,8 @@ def get_cu_file_str( group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], - dtype_in=dtype_literal[dtype_in], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], ) return content @@ -48,7 +49,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"single_decode_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" + r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/setup.py b/python/setup.py index fff016b7..c758f394 100644 --- a/python/setup.py +++ b/python/setup.py @@ -94,12 +94,12 @@ def get_instantiation_cu() -> List[str]: idtypes = ["i32"] prefill_dtypes = ["f16"] decode_dtypes = ["f16"] + fp8_dtypes = ["e4m3", "e5m2"] if enable_bf16: prefill_dtypes.append("bf16") decode_dtypes.append("bf16") - fp8_dtypes = [] if enable_fp8: - fp8_dtypes = ["e4m3", "e5m2"] + decode_dtypes.extend(fp8_dtypes) files = [] # single decode files @@ -114,29 +114,17 @@ def get_instantiation_cu() -> List[str]: kv_layouts, pos_encoding_modes, ): - for dtype in decode_dtypes: - fname = f"single_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" + for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): + dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" + fname = f"single_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" files.append(prefix + "/" + fname) content = generate_single_decode_inst.get_cu_file_str( group_size, head_dim, kv_layout, pos_encoding_mode, - dtype, - dtype, - ) - write_if_different(root / prefix / fname, content) - - for dtype_in in fp8_dtypes: - dtype_out = "f16" - fname = f"single_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype_in}_dtypeout_{dtype_out}.cu" - files.append(prefix + "/" + fname) - content = generate_single_decode_inst.get_cu_file_str( - group_size, - head_dim, - kv_layout, - pos_encoding_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, ) write_if_different(root / prefix / fname, content) @@ -154,58 +142,33 @@ def get_instantiation_cu() -> List[str]: pos_encoding_modes, ): for idtype in idtypes: - for dtype in decode_dtypes: - fname = f"batch_paged_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" - files.append(prefix + "/" + fname) - content = generate_batch_paged_decode_inst.get_cu_file_str( - group_size, - head_dim, - kv_layout, - pos_encoding_mode, - dtype, - dtype, - idtype, - ) - write_if_different(root / prefix / fname, content) - - for dtype_in in fp8_dtypes: - dtype_out = "f16" - fname = f"batch_paged_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype_in}_dtypeout_{dtype_out}_idtype_{idtype}.cu" + for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): + dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" + fname = f"batch_paged_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_paged_decode_inst.get_cu_file_str( group_size, head_dim, kv_layout, pos_encoding_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, idtype, ) write_if_different(root / prefix / fname, content) - for dtype in decode_dtypes: - fname = f"batch_padded_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" - files.append(prefix + "/" + fname) - content = generate_batch_padded_decode_inst.get_cu_file_str( - group_size, - head_dim, - kv_layout, - pos_encoding_mode, - dtype, - dtype, - ) - write_if_different(root / prefix / fname, content) - - for dtype_in in fp8_dtypes: - dtype_out = "f16" - fname = f"batch_padded_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype_in}_dtypeout_{dtype_out}.cu" + for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): + dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" + fname = f"batch_padded_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" files.append(prefix + "/" + fname) content = generate_batch_padded_decode_inst.get_cu_file_str( group_size, head_dim, kv_layout, pos_encoding_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, ) write_if_different(root / prefix / fname, content) diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index dc48fb90..d7f2f8ca 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -144,8 +144,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( return cudaSuccess; } -template -cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, DTypeOut* tmp, +template +cudaError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len, uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, @@ -174,8 +174,8 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut return cudaSuccess; } -template -cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, +template +cudaError_t BatchDecodeWithPaddedKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, @@ -199,17 +199,17 @@ cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTy {DISPATCH_pos_encoding_mode( pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, { return BatchDecodeWithPaddedKVCacheDispatched( + POS_ENCODING_MODE, DTypeQ, DtypeKV, DTypeOut>( q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, stream); })})})}); return cudaSuccess; } -template +template cudaError_t BatchDecodeWithPagedKVCache( - DTypeIn* q, IdType* q_offset, paged_kv_t paged_kv, + DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, @@ -230,8 +230,8 @@ cudaError_t BatchDecodeWithPagedKVCache( {DISPATCH_head_dim( head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { return BatchDecodeWithPagedKVCacheDispatched( + kv_layout, POS_ENCODING_MODE, DTypeQ, + DTypeKV, DTypeOut, IdType>( q, q_offset, paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, rope_theta, stream); })})}); @@ -244,7 +244,8 @@ cudaError_t BatchDecodeWithPagedKVCache( * for cooperative kernels. * \tparam page_storage Whether to store indices or pointers of each active page * \tparam kv_layout The layout of last 3 dimensions in KV-Cache - * \tparam DTypeIn The data type of input tensor. + * \tparam DTypeQ The data type of query tensor. + * \tparam DTypeKV The data type of key-value tensor. * \tparam DTypeOut The data type of output tensor. * \tparam IdType The data type of index tensor. * \param handler The handler for the batch decode forward request. @@ -260,11 +261,11 @@ cudaError_t BatchDecodeWithPagedKVCache( * \note This wrapper function should be only called after we call BeginForward function in the * BatchDecodeHandler. */ -template +template cudaError_t BatchDecodeWithPagedKVCacheWrapper( - BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, + BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, + paged_kv_t paged_kv, DTypeOut* o, float* lse, uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { @@ -284,14 +285,14 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { return BatchDecodeWithPagedKVCacheWrapperDispatched( + DTypeQ, DTypeKV, DTypeOut, IdType>( handler, q, q_offset, paged_kv, o, lse, sm_scale, rope_scale, rope_theta, stream); })})}); return cudaSuccess; } -template +template cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, @@ -308,7 +309,7 @@ cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* bu DISPATCH_head_dim(head_dim, HEAD_DIM, { DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { return handler->BeginForwardDispatched( + POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( buffer, workspace_size_in_bytes, indptr, last_page_len, batch_size, num_qo_heads, page_size); }); From 553831ec7b34ee524c0bbaf11228cbeb65a200fa Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 5 Jun 2024 15:45:59 -0700 Subject: [PATCH 2/9] Fix --- python/generate_batch_paged_decode_inst.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/generate_batch_paged_decode_inst.py b/python/generate_batch_paged_decode_inst.py index 56fe44aa..a8b099ad 100644 --- a/python/generate_batch_paged_decode_inst.py +++ b/python/generate_batch_paged_decode_inst.py @@ -34,7 +34,7 @@ def get_cu_file_str( constexpr PageStorage page_storage = PageStorage::kIndices; -template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{group_size}, {head_dim}, page_storage, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv} {dtype_out}, {idtype}>( +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{group_size}, {head_dim}, page_storage, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>( {dtype_q}* q, {idtype}* q_offset, paged_kv_t paged_kv, kv_partition_info_t<{idtype}> kv_partition_info, From 212c0a540f93e733db55019770b3224e464158c3 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 6 Jun 2024 15:44:17 -0700 Subject: [PATCH 3/9] WIP --- python/csrc/batch_decode.cu | 20 +++---- python/csrc/pytorch_extension_utils.h | 74 +++++++++++++++++++++++-- python/csrc/single_decode.cu | 79 ++++++++++++++------------- 3 files changed, 121 insertions(+), 52 deletions(-) diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 20041905..2b13ebc3 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -57,8 +57,8 @@ std::vector batch_decode_with_padded_kv_cache( } if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE(k_padded.scalar_type(), kv_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_pos_encoding_mode( @@ -86,8 +86,8 @@ std::vector batch_decode_with_padded_kv_cache( }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE(k_padded.scalar_type(), kv_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { q_type* tmp = nullptr; return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { @@ -140,7 +140,7 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( handler_->SetCUDAStream(torch_current_stream); if (is_float8_tensor(empty_data)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(empty_data.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { @@ -164,7 +164,7 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(empty_data.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { @@ -244,8 +244,8 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( } if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE(paged_kv_data.scalar_type(), kv_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(paged_kv_data.scalar_type(), kv_type, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { @@ -276,8 +276,8 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { - return DISPATCH_PYTORCH_DTYPE_TO_CTYPE(paged_kv_data.scalar_type(), kv_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(paged_kv_data.scalar_type(), kv_type, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index e3bcaaee..f67fa369 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -31,8 +31,74 @@ using namespace flashinfer; -#if defined (FLASHINFER_ENABLE_BF16) && defined (FLASHINFER_ENABLE_FP8) + +#ifdef FLASHINFER_ENABLE_BF16 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#else +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#endif + +#ifdef FLASHINFER_ENABLE_FP8 +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#else +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + }() +#endif + +#if defined (FLASHINFER_ENABLE_BF16) && defined (FLASHINFER_ENABLE_FP8) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ @@ -59,7 +125,7 @@ using namespace flashinfer; } \ }() #elif defined (FLASHINFER_ENABLE_BF16) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ @@ -78,7 +144,7 @@ using namespace flashinfer; } \ }() #elif defined (FLASHINFER_ENABLE_FP8) -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Float8_e4m3fn: { \ @@ -97,7 +163,7 @@ using namespace flashinfer; } \ }() #else -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Half: { \ diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index 6a767f91..608f1517 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -32,8 +32,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc CHECK_DIM(3, v); CHECK_SHAPE(k, v); CHECK_EQ(q.size(1), k.size(2)); - CHECK_EQ(q.scalar_type(), k.scalar_type()); - CHECK_EQ(q.scalar_type(), v.scalar_type()); + CHECK_EQ(v.scalar_type(), k.scalar_type()); unsigned int num_qo_heads = q.size(0); unsigned int head_dim = q.size(1); unsigned int kv_len, num_kv_heads; @@ -51,47 +50,51 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - SingleDecodeWithKVCacheDispatched( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, - rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + SingleDecodeWithKVCacheDispatched( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, + rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); }); }); }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - SingleDecodeWithKVCacheDispatched( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, - rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + SingleDecodeWithKVCacheDispatched( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, + rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); }); }); }); From bf755f0be7967e4dfd30763d85b38b7e6e4f7b46 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 10 Jun 2024 16:06:09 -0700 Subject: [PATCH 4/9] WIP --- include/flashinfer/attention/handler.cuh | 29 ++++++++++++------------ 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 998cd44a..b4e820b7 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -30,11 +30,11 @@ namespace flashinfer { template + PageStorage page_storage, QKVLayout kv_layout, typename DTypeQ, typename DTypeKV, + typename DTypeOut, typename IdType> __global__ void BatchDecodeWithPagedKVCacheKernel( - DTypeIn* __restrict__ q, IdType* __restrict__ q_offset, - paged_kv_t paged_kv, + DTypeQ* __restrict__ q, IdType* __restrict__ q_offset, + paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float rope_rcp_scale, float rope_rcp_theta); @@ -84,7 +84,8 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc * \brief Estimate the temporary buffer size and the maximum grid size for the * partition-kv BatchDecodeWithPagedKVCache kernel * \tparam page_storage Whether to store indices or pointers of each active page - * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeQ A template type indicates the query data type + * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \tparam IdType A template type indicates the index data type * \param tmp_size The estimated temporary buffer size, return 0 if not use partition-kv kernel @@ -98,27 +99,27 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc * \return status Indicates whether CUDA calls are successful */ template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { - constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); constexpr uint32_t bdy = GROUP_SIZE; constexpr uint32_t num_threads = std::max(128U, bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); - constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U; + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; const uint32_t smem_size = - 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) + - std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float)); + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel< /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx, - bdy, bdz, page_storage, kv_layout, DTypeIn, DTypeOut, IdType>; + bdy, bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>; int num_blocks_per_sm = 0; int num_sm = 0; int dev_id = 0; @@ -272,7 +273,7 @@ class BatchDecodeHandler { } template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, uint32_t num_qo_heads, uint32_t page_size) { @@ -280,8 +281,8 @@ class BatchDecodeHandler { uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched; + kv_layout, POS_ENCODING_MODE, DTypeQ, + DTypeKV, DTypeOut, IdType>; FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr, num_qo_heads, page_size, From 5880b4528b404472ac75a73e232df2608429c7d5 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 11 Jun 2024 14:22:26 -0700 Subject: [PATCH 5/9] WIP --- include/flashinfer/attention/decode.cuh | 2 ++ include/flashinfer/attention/handler.cuh | 10 +++++----- include/flashinfer/decode_attention_decl.cuh | 1 + python/csrc/batch_decode.cu | 12 ++++++------ python/csrc/single_decode.cu | 4 ++-- src/flashinfer_ops.cuh | 5 +++-- 6 files changed, 19 insertions(+), 15 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index fde71c2f..74446a26 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -522,6 +522,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) { + static_assert(!std::is_same_v); auto block = cg::this_thread_block(); sm_scale *= math::log2e; @@ -848,6 +849,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { + static_assert(!std::is_same_v); const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; const uint32_t num_kv_heads = paged_kv.num_heads; diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index b4e820b7..eda4234b 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -84,7 +84,6 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc * \brief Estimate the temporary buffer size and the maximum grid size for the * partition-kv BatchDecodeWithPagedKVCache kernel * \tparam page_storage Whether to store indices or pointers of each active page - * \tparam DTypeQ A template type indicates the query data type * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \tparam IdType A template type indicates the index data type @@ -99,11 +98,12 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc * \return status Indicates whether CUDA calls are successful */ template + PosEncodingMode POS_ENCODING_MODE, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { + static_assert(!std::is_same_v); constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; @@ -119,7 +119,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel< /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx, - bdy, bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>; + bdy, bdz, page_storage, kv_layout, DTypeKV, DTypeKV, DTypeOut, IdType>; int num_blocks_per_sm = 0; int num_sm = 0; int dev_id = 0; @@ -273,7 +273,7 @@ class BatchDecodeHandler { } template + PosEncodingMode POS_ENCODING_MODE, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, uint32_t num_qo_heads, uint32_t page_size) { @@ -281,7 +281,7 @@ class BatchDecodeHandler { uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched; FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr, num_qo_heads, diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index ae9e29f0..0da8bd2e 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -56,6 +56,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { + static_assert(!std::is_same_v); paged_kv_t new_paged_kv = paged_kv; kv_partition_info_t kv_partition_info; DTypeOut* tmp = handler->GetTempFloatBuffer(); diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 2b13ebc3..45c19440 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -57,7 +57,7 @@ std::vector batch_decode_with_padded_kv_cache( } if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { @@ -86,7 +86,7 @@ std::vector batch_decode_with_padded_kv_cache( }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { q_type* tmp = nullptr; return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { @@ -140,7 +140,7 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( handler_->SetCUDAStream(torch_current_stream); if (is_float8_tensor(empty_data)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(empty_data.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { @@ -164,7 +164,7 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(empty_data.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { @@ -244,7 +244,7 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( } if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(paged_kv_data.scalar_type(), kv_type, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { @@ -276,7 +276,7 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(paged_kv_data.scalar_type(), kv_type, [&] { return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index 608f1517..6f591bad 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -50,7 +50,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { @@ -75,7 +75,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(q.scalar_type(), q_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index d7f2f8ca..6f55d74d 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -214,6 +214,7 @@ cudaError_t BatchDecodeWithPagedKVCache( uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { + static_assert(!std::is_same_v); const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; const uint32_t batch_size = paged_kv.batch_size; @@ -291,7 +292,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( return cudaSuccess; } -template cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* buffer, size_t workspace_size_in_bytes, IdType* indptr, @@ -309,7 +310,7 @@ cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* bu DISPATCH_head_dim(head_dim, HEAD_DIM, { DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { return handler->BeginForwardDispatched( + POS_ENCODING_MODE, DTypeKV, DTypeOut, IdType>( buffer, workspace_size_in_bytes, indptr, last_page_len, batch_size, num_qo_heads, page_size); }); From 258814db7d4bb234b206383bcba9ed6b97fae2f2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 11 Jun 2024 14:24:29 -0700 Subject: [PATCH 6/9] Update test --- python/tests/test_batch_decode_kernels.py | 56 +++++++++++++++-------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index efd75327..ff5577bd 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -30,7 +30,10 @@ @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) @pytest.mark.parametrize( - "dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] + "q_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] +) +@pytest.mark.parametrize( + "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] ) def test_batch_decode_with_paged_kv_cache( batch_size, @@ -41,9 +44,10 @@ def test_batch_decode_with_paged_kv_cache( head_dim, kv_layout, pos_encoding_mode, - dtype, + q_dtype, + kv_dtype, ): - q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(dtype) + q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(q_dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( @@ -68,9 +72,9 @@ def test_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - dtype, + kv_dtype, ) - o = wrapper.forward(q, kv_data.to(dtype), pos_encoding_mode=pos_encoding_mode) + o = wrapper.forward(q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode) for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] @@ -90,7 +94,7 @@ def test_batch_decode_with_paged_kv_cache( .reshape(-1, num_kv_heads, head_dim), ], dim=0, - ).to(dtype) + ).to(kv_dtype) vi = torch.cat( [ kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] @@ -105,7 +109,7 @@ def test_batch_decode_with_paged_kv_cache( .reshape(-1, num_kv_heads, head_dim), ], dim=0, - ).to(dtype) + ).to(kv_dtype) o_ref_i = flashinfer.single_decode_with_kv_cache( qi, ki, vi, pos_encoding_mode=pos_encoding_mode ) @@ -123,7 +127,10 @@ def test_batch_decode_with_paged_kv_cache( @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) @pytest.mark.parametrize( - "dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] + "q_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] +) +@pytest.mark.parametrize( + "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] ) def test_cuda_graph_batch_decode_with_paged_kv_cache( batch_size, @@ -134,9 +141,10 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( head_dim, kv_layout, pos_encoding_mode, - dtype, + q_dtype, + kv_dtype, ): - q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(dtype) + q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(q_dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( @@ -178,7 +186,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - dtype, + kv_dtype, ) # warmup s = torch.cuda.Stream() @@ -186,13 +194,13 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( with torch.cuda.stream(s): for _ in range(3): o = wrapper.forward( - q, kv_data.to(dtype), pos_encoding_mode=pos_encoding_mode + q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode ) torch.cuda.current_stream().wait_stream(s) # capture g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): - o = wrapper.forward(q, kv_data.to(dtype), pos_encoding_mode=pos_encoding_mode) + o = wrapper.forward(q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode) wrapper.end_forward() # replay wrapper.begin_forward( @@ -204,7 +212,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - dtype, + kv_dtype, ) g.replay() @@ -230,7 +238,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( .reshape(-1, num_kv_heads, head_dim), ], dim=0, - ).to(dtype) + ).to(kv_dtype) vi = torch.cat( [ kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] @@ -245,7 +253,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( .reshape(-1, num_kv_heads, head_dim), ], dim=0, - ).to(dtype) + ).to(kv_dtype) o_ref_i = flashinfer.single_decode_with_kv_cache( qi, ki, vi, pos_encoding_mode=pos_encoding_mode ) @@ -256,11 +264,21 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( if __name__ == "__main__": test_batch_decode_with_paged_kv_cache( - 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16 + 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16, torch.float16 ) test_batch_decode_with_paged_kv_cache( - 12, 54, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2 + 12, 54, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float8_e5m2 ) test_cuda_graph_batch_decode_with_paged_kv_cache( - 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16 + 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16, torch.float16 + ) + + test_batch_decode_with_paged_kv_cache( + 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float16 ) + test_batch_decode_with_paged_kv_cache( + 12, 54, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float16 + ) + test_cuda_graph_batch_decode_with_paged_kv_cache( + 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float16 + ) \ No newline at end of file From ca02f7dce6fc8f56bb4e93d0623b10abd6273f3e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 11 Jun 2024 14:30:51 -0700 Subject: [PATCH 7/9] Cleanup --- include/flashinfer/attention/decode.cuh | 1 - include/flashinfer/attention/handler.cuh | 1 - include/flashinfer/decode_attention_decl.cuh | 1 - src/flashinfer_ops.cuh | 1 - 4 files changed, 4 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 10978c37..67a00bed 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -523,7 +523,6 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse, bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) { - static_assert(!std::is_same_v); auto block = cg::this_thread_block(); sm_scale *= math::log2e; diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 1ecd0dd7..51766559 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -105,7 +105,6 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { - static_assert(!std::is_same_v); constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index 19cdb51b..40998418 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -57,7 +57,6 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { - static_assert(!std::is_same_v); paged_kv_t new_paged_kv = paged_kv; kv_partition_info_t kv_partition_info; DTypeOut* tmp_v = handler->GetTempV(); diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index e5e69eff..7f6d0001 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -215,7 +215,6 @@ cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV( PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { - static_assert(!std::is_same_v); const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; const uint32_t batch_size = paged_kv.batch_size; From 57860afee1055199fa8866b9faf796a02aec9810 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 11 Jun 2024 14:33:02 -0700 Subject: [PATCH 8/9] Add comment --- include/flashinfer/attention/handler.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 51766559..622d3eaa 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -118,6 +118,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); + // Note that the dtype of Q should not impact the cudaOccupancyMaxActiveBlocksPerMultiprocessor + // return, which is why we just use DTypeKV as it simplifies the API. auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel< /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx, bdy, bdz, page_storage, kv_layout, DTypeKV, DTypeKV, DTypeOut, IdType>; From c7c959d59483c69b70b065d936c1e2a501c5cb7e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 13 Jun 2024 16:09:36 -0700 Subject: [PATCH 9/9] Update BatchDecodeWithPagedKVCacheWorkEstimation --- include/flashinfer/attention/handler.cuh | 8 +-- python/csrc/batch_decode.cu | 84 ++++++++++++----------- python/csrc/flashinfer_ops.h | 2 +- python/flashinfer/decode.py | 19 ++++- python/tests/test_batch_decode_kernels.py | 4 ++ src/bench_batch_decode.cu | 2 +- src/bench_cascade.cu | 2 +- src/flashinfer_ops.cuh | 6 +- src/test_batch_decode.cu | 2 +- src/test_cascade.cu | 4 +- src/tvm_wrapper.cu | 2 +- 11 files changed, 78 insertions(+), 57 deletions(-) diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 622d3eaa..6d8bf3e0 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -100,7 +100,7 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc * \return status Indicates whether CUDA calls are successful */ template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads, @@ -122,7 +122,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( // return, which is why we just use DTypeKV as it simplifies the API. auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel< /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx, - bdy, bdz, page_storage, kv_layout, DTypeKV, DTypeKV, DTypeOut, IdType>; + bdy, bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>; int num_blocks_per_sm = 0; int num_sm = 0; int dev_id = 0; @@ -296,7 +296,7 @@ class BatchDecodeHandler { bool* GetBlockValidMask() const { return block_valid_mask_; } template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, uint32_t num_qo_heads, uint32_t page_size) { @@ -306,7 +306,7 @@ class BatchDecodeHandler { auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched; + DTypeQ, DTypeKV, DTypeOut, IdType>; FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr, num_qo_heads, page_size, diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 45c19440..1af97f01 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -124,7 +124,7 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode, - torch::Tensor empty_data) { + torch::Tensor empty_q_data, torch::Tensor empty_kv_data) { // NOTE(zihao): not necessary to be CUDA tensor CHECK_CONTIGUOUS(indptr); CHECK_CONTIGUOUS(last_page_len); @@ -139,50 +139,54 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_->SetCUDAStream(torch_current_stream); - if (is_float8_tensor(empty_data)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - handler_->BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + if (is_float8_tensor(empty_q_data)) { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_q_data.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(empty_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + handler_->BeginForwardDispatched( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, + page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); }); }); }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - handler_->BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(empty_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + handler_->BeginForwardDispatched( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, + page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); }); }); }); diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 06ff2dc2..e7ab8a07 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -77,7 +77,7 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper { void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, - unsigned int pos_encoding_mode, torch::Tensor empty_data); + unsigned int pos_encoding_mode, torch::Tensor empty_q_data, torch::Tensor empty_kv_data); void EndForward(); void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 20b53c0d..3e291335 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -540,6 +540,7 @@ def begin_forward( page_size: int, pos_encoding_mode: str = "NONE", data_type: Union[str, torch.dtype] = "float16", + q_data_type: Optional[Union[str, torch.dtype]] = None, ): r"""Create auxiliary data structures for batch decode for multiple forward calls within the same decode step. @@ -566,6 +567,9 @@ def begin_forward( ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. data_type : Union[str, torch.dtype] The data type of the paged kv cache + q_data_type : Optional[Union[str, torch.dtype]] + The data type of the query tensor. If None, will be set to + ``data_type``. Note ---- @@ -599,8 +603,16 @@ def begin_forward( self._paged_kv_indices_buf = indices self._paged_kv_last_page_len_buf = last_page_len - # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info - empty_data = torch.empty( + # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info + if not q_data_type: + q_data_type = data_type + empty_q_data = torch.empty( + 0, + dtype=( + getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type + ), + ) + empty_kv_data = torch.empty( 0, dtype=( getattr(torch, data_type) if isinstance(data_type, str) else data_type @@ -616,7 +628,8 @@ def begin_forward( head_dim, page_size, PosEncodingMode[pos_encoding_mode].value, - empty_data, + empty_q_data, + empty_kv_data, ) def end_forward(self): diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index 95ba3ceb..d7dc92a0 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -73,6 +73,7 @@ def test_batch_decode_with_paged_kv_cache( page_size, "NONE", kv_dtype, + q_dtype, ) o = wrapper.forward(q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode) @@ -182,6 +183,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( page_size, "NONE", kv_dtype, + q_dtype, ) # warmup s = torch.cuda.Stream() @@ -213,6 +215,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( page_size, "NONE", kv_dtype, + q_dtype, ) g.replay() @@ -233,6 +236,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( page_size, "NONE", kv_dtype, + q_dtype, ) g.replay() diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu index e0b793ff..9aa1b919 100644 --- a/src/bench_batch_decode.cu +++ b/src/bench_batch_decode.cu @@ -73,7 +73,7 @@ void bench_flashinfer_batch_decode(nvbench::state& state) { size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); // begin forward - BatchDecodeHandlerBeginForward( + BatchDecodeHandlerBeginForward( &handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode); diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu index 34011d3e..ec09cdb5 100644 --- a/src/bench_cascade.cu +++ b/src/bench_cascade.cu @@ -109,7 +109,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) { BatchDecodeHandler cascade_handler; size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - BatchDecodeHandlerBeginForward( + BatchDecodeHandlerBeginForward( &cascade_handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh index 7f6d0001..84294e00 100644 --- a/src/flashinfer_ops.cuh +++ b/src/flashinfer_ops.cuh @@ -294,8 +294,8 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( return cudaSuccess; } -template +template cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, @@ -312,7 +312,7 @@ cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* bu DISPATCH_head_dim(head_dim, HEAD_DIM, { DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { return handler->BeginForwardDispatched( + POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( buffer, workspace_size_in_bytes, indptr, last_page_len, batch_size, num_qo_heads, page_size); }); diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 6b14372d..b8f77a97 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -100,7 +100,7 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si flashinfer::BatchDecodeHandler handler; size_t workspace_size_in_bytes = 32 * 1024 * 1024; thrust::device_vector buffer(workspace_size_in_bytes); - BatchDecodeHandlerBeginForward( + BatchDecodeHandlerBeginForward( &handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes, kv_indptr.data(), kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode); diff --git a/src/test_cascade.cu b/src/test_cascade.cu index e760804a..0b1e6d18 100644 --- a/src/test_cascade.cu +++ b/src/test_cascade.cu @@ -283,12 +283,12 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size, thrust::device_vector buffer_baseline(workspace_size_in_bytes), buffer_cascade(workspace_size_in_bytes); - BatchDecodeHandlerBeginForward( + BatchDecodeHandlerBeginForward( &baseline_handler, (void*)thrust::raw_pointer_cast(buffer_baseline.data()), workspace_size_in_bytes, kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); - BatchDecodeHandlerBeginForward( + BatchDecodeHandlerBeginForward( &cascade_handler, (void*)thrust::raw_pointer_cast(buffer_cascade.data()), workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone); diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index f3e6c6ec..b7682972 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -424,7 +424,7 @@ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward( batch_decode_handlers[handler_idx].SetCUDAStream(static_cast(copy_stream)); DISPATCH_TVM_CUDA_IDTYPE(page_table_indptr->dtype, dtype_idx, { cudaError_t status = - BatchDecodeHandlerBeginForward( + BatchDecodeHandlerBeginForward( batch_decode_handlers + handler_idx, static_cast(workspace_buffer->data), workspace_size_in_bytes, static_cast(page_table_indptr->data) +