From 52335c7caef02f5a78df418fd8825d8da289ceab Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 16 Aug 2024 06:14:44 +0000 Subject: [PATCH 01/12] upd --- include/flashinfer/attention/handler.cuh | 33 ++++++++++++++++++++++++ include/flashinfer/attention/prefill.cuh | 5 ++++ include/flashinfer/sampling.cuh | 17 +++++------- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index bcb9bc82..08d51ea4 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -556,12 +556,17 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz if (avg_packed_qo_len > 64 && head_dim < 256) { warp_layout = WarpLayout::k4x1x2; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2) } else { +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) if (avg_packed_qo_len > 16) { warp_layout = WarpLayout::k4x1x1; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1) } else { // avg_packed_qo_len <= 16 warp_layout = WarpLayout::k1x4x1; // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1) } +#else + // NOTE(Zihao): not enough shared memory for 1x4x1 layout + warp_layout = WarpLayout::k4x1x1; +#endif } const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout); @@ -593,6 +598,34 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); } + std::cout << kv_chunk_size << " " << new_batch_size << std::endl; + // print request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr + std::cout << "request_indices: "; + for (uint32_t i = 0; i < request_indices.size(); ++i) { + std::cout << request_indices[i] << " "; + } + std::cout << std::endl; + std::cout << "qo_tile_indices: "; + for (uint32_t i = 0; i < qo_tile_indices.size(); ++i) { + std::cout << qo_tile_indices[i] << " "; + } + std::cout << std::endl; + std::cout << "kv_tile_indices: "; + for (uint32_t i = 0; i < kv_tile_indices.size(); ++i) { + std::cout << kv_tile_indices[i] << " "; + } + std::cout << std::endl; + std::cout << "merge_indptr: "; + for (uint32_t i = 0; i < merge_indptr.size(); ++i) { + std::cout << merge_indptr[i] << " "; + } + std::cout << std::endl; + std::cout << "o_indptr: "; + for (uint32_t i = 0; i < o_indptr.size(); ++i) { + std::cout << o_indptr[i] << " "; + } + std::cout << std::endl; + // step 4: multiply kv_chunk_size by page_size kv_chunk_size *= page_size; diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 006d9753..ba6fdc32 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1936,11 +1936,16 @@ cudaError_t SinglePrefillWithKVCacheDispatched( if (unpacked_qo_len > 64 && HEAD_DIM < 256) { warp_layout = WarpLayout::k4x1x2; } else { +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) if (unpacked_qo_len > 16) { warp_layout = WarpLayout::k4x1x1; } else { warp_layout = WarpLayout::k1x4x1; } +#else + NOTE(Zihao): not enough shared memory for 1x4x1 layout + warp_layout = WarpLayout::k4x1x1; +#endif } DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 3cda7415..57cf8057 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -47,6 +47,12 @@ constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; #define 
FLASHINFER_CUB_SUBTRACTLEFT_DEFINED #endif +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +constexpr uint32_t BLOCK_THREADS = 1024; +#else +constexpr uint32_t BLOCK_THREADS = 512; +#endif + template struct Pair { T value; @@ -642,7 +648,6 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp template cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size, uint32_t d, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); @@ -664,7 +669,6 @@ template cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, IdType* output, IdType* row_indices, uint32_t batch_size, uint32_t d, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); @@ -686,7 +690,6 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, uint32_t max_top_k_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -712,7 +715,6 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T* top_p_arr, uint32_t batch_size, T top_p_val, uint32_t d, uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -738,7 +740,6 @@ template cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr, IdType* output, bool* success, uint32_t batch_size, float min_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -764,7 +765,6 @@ cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k IdType* output, bool* success, uint32_t batch_size, IdType top_k_val, T top_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -1166,7 +1166,6 @@ template cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, uint32_t batch_size, float top_p_val, uint32_t d, cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1186,7 +1185,6 @@ template cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1206,7 +1204,6 @@ template cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { - const uint32_t BLOCK_THREADS = 1024; const uint32_t 
vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1352,7 +1349,6 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o bool* success, IdType* row_indices, T* top_p_arr, uint32_t batch_size, uint32_t d, uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -1381,7 +1377,6 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids IdType* output_emitted_token_num, uint32_t batch_size, uint32_t num_speculative_tokens, uint32_t d, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = From b4aacf57930469aee428a389c1ab460bdfd23554 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 16 Aug 2024 06:15:55 +0000 Subject: [PATCH 02/12] upd --- include/flashinfer/attention/handler.cuh | 28 ------------------------ 1 file changed, 28 deletions(-) diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 08d51ea4..48c8126a 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -598,34 +598,6 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); } - std::cout << kv_chunk_size << " " << new_batch_size << std::endl; - // print request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr - std::cout << "request_indices: "; - for (uint32_t i = 0; i < request_indices.size(); ++i) { - std::cout << request_indices[i] << " "; - } - std::cout << std::endl; - std::cout << "qo_tile_indices: "; - for (uint32_t i = 0; i < qo_tile_indices.size(); ++i) { - std::cout << qo_tile_indices[i] << " "; - } - std::cout << std::endl; - std::cout << "kv_tile_indices: "; - for (uint32_t i = 0; i < kv_tile_indices.size(); ++i) { - std::cout << kv_tile_indices[i] << " "; - } - std::cout << std::endl; - std::cout << "merge_indptr: "; - for (uint32_t i = 0; i < merge_indptr.size(); ++i) { - std::cout << merge_indptr[i] << " "; - } - std::cout << std::endl; - std::cout << "o_indptr: "; - for (uint32_t i = 0; i < o_indptr.size(); ++i) { - std::cout << o_indptr[i] << " "; - } - std::cout << std::endl; - // step 4: multiply kv_chunk_size by page_size kv_chunk_size *= page_size; From 06359a7f1b3dfc7fead135fd7ae3d2248b24fb3c Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 16 Aug 2024 06:17:43 +0000 Subject: [PATCH 03/12] upd --- include/flashinfer/attention/prefill.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index ba6fdc32..22b632eb 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1943,7 +1943,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched( warp_layout = WarpLayout::k1x4x1; } #else - NOTE(Zihao): not enough shared memory for 1x4x1 layout + # NOTE(Zihao): not enough shared memory for 1x4x1 layout warp_layout = WarpLayout::k4x1x1; #endif } From 3ba3dd5507a6aef035842e21bdec1e59b50d6bff Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 16 Aug 2024 06:18:13 +0000 Subject: [PATCH 04/12] upd --- include/flashinfer/attention/prefill.cuh | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 22b632eb..ffd9913c 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1943,7 +1943,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched( warp_layout = WarpLayout::k1x4x1; } #else - # NOTE(Zihao): not enough shared memory for 1x4x1 layout + // NOTE(Zihao): not enough shared memory for 1x4x1 layout warp_layout = WarpLayout::k4x1x1; #endif } From bf1e446b845730f7b06f6307165fbe030cd8f92c Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 17 Aug 2024 06:26:13 +0000 Subject: [PATCH 05/12] upd --- include/flashinfer/sampling.cuh | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 57cf8057..6154393d 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -47,11 +47,13 @@ constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; #define FLASHINFER_CUB_SUBTRACTLEFT_DEFINED #endif +constexpr uint32_t get_block_threads() { #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) -constexpr uint32_t BLOCK_THREADS = 1024; + return 1024; #else -constexpr uint32_t BLOCK_THREADS = 512; + return 512; #endif +} template struct Pair { @@ -648,6 +650,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp template cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size, uint32_t d, bool deterministic, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(T), d); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); @@ -669,6 +672,7 @@ template cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, IdType* output, IdType* row_indices, uint32_t batch_size, uint32_t d, bool deterministic, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(T), d); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); @@ -690,6 +694,7 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, uint32_t max_top_k_rounds, bool deterministic, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -715,6 +720,7 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T* top_p_arr, uint32_t batch_size, T top_p_val, uint32_t d, uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -740,6 +746,7 @@ template cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr, IdType* output, bool* success, uint32_t batch_size, float min_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -765,6 +772,7 @@ cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k IdType* output, bool* success, uint32_t batch_size, IdType 
top_k_val, T top_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -1166,6 +1174,7 @@ template cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, uint32_t batch_size, float top_p_val, uint32_t d, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1185,6 +1194,7 @@ template cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1204,6 +1214,7 @@ template cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1349,6 +1360,7 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o bool* success, IdType* row_indices, T* top_p_arr, uint32_t batch_size, uint32_t d, uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -1377,6 +1389,7 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids IdType* output_emitted_token_num, uint32_t batch_size, uint32_t num_speculative_tokens, uint32_t d, bool deterministic, cudaStream_t stream = 0) { + constexpr uint32_t BLOCK_THREADS = get_block_threads(); const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = From 4c81c1f13c89b0c464e5976c8e827e460d8d6b89 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 17 Aug 2024 06:36:48 +0000 Subject: [PATCH 06/12] upd --- include/flashinfer/attention/decode.cuh | 2 -- include/flashinfer/attention/prefill.cuh | 2 -- include/flashinfer/mma.cuh | 6 ---- include/flashinfer/sampling.cuh | 46 ++++++++++++++---------- python/csrc/pytorch_extension_utils.h | 4 --- 5 files changed, 27 insertions(+), 33 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 09ef0941..ab62d000 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -20,9 +20,7 @@ #include #include -#ifdef FLASHINFER_ENABLE_FP8 #include -#endif #include #include diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index ffd9913c..f5d42c9a 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -18,9 +18,7 @@ #include #include #include -#ifdef FLASHINFER_ENABLE_FP8 #include -#endif #include #include "../cp_async.cuh" diff --git a/include/flashinfer/mma.cuh b/include/flashinfer/mma.cuh index 82f457a5..3f6ae16d 100644 --- a/include/flashinfer/mma.cuh +++ b/include/flashinfer/mma.cuh @@ -18,9 +18,7 @@ #include #include -#ifdef FLASHINFER_ENABLE_FP8 #include -#endif 
#include #include @@ -206,7 +204,6 @@ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { #endif } -#ifdef FLASHINFER_ENABLE_FP8 /*! * \brief Wrapper of two mma m16n8k32 instructions for row major and column major f8 matrix * multiplication, accumulated in f32. @@ -307,7 +304,6 @@ __device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uin "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"); #endif } -#endif /*! * \brief Wrapper of two mma m16n8k16 instructions for row major and column major f16 matrix @@ -476,7 +472,6 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u #endif } -#ifdef FLASHINFER_ENABLE_FP8 /*! * \brief Use mma instructions to compute rowsum. */ @@ -515,7 +510,6 @@ __device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) { "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"); #endif } -#endif /*! * \brief Use mma instructions to compute rowsum. diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 6154393d..b8ea6387 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -47,14 +47,6 @@ constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; #define FLASHINFER_CUB_SUBTRACTLEFT_DEFINED #endif -constexpr uint32_t get_block_threads() { -#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) - return 1024; -#else - return 512; -#endif -} - template struct Pair { T value; @@ -650,7 +642,7 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp template cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output, uint32_t batch_size, uint32_t d, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); + constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); @@ -672,7 +664,7 @@ template cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples, IdType* output, IdType* row_indices, uint32_t batch_size, uint32_t d, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); + constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); dim3 nblks(batch_size); dim3 nthrs(BLOCK_THREADS); @@ -694,7 +686,11 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, uint32_t max_top_k_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) + constexpr uint32_t BLOCK_THREADS = 1024; +#else + constexpr uint32_t BLOCK_THREADS = 512; +#endif const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -720,7 +716,11 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T* top_p_arr, uint32_t batch_size, T top_p_val, uint32_t d, uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) + constexpr uint32_t BLOCK_THREADS = 1024; +#else + constexpr uint32_t BLOCK_THREADS = 512; +#endif const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -746,7 
+746,11 @@ template cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr, IdType* output, bool* success, uint32_t batch_size, float min_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) + constexpr uint32_t BLOCK_THREADS = 1024; +#else + constexpr uint32_t BLOCK_THREADS = 512; +#endif const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -772,7 +776,11 @@ cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k IdType* output, bool* success, uint32_t batch_size, IdType top_k_val, T top_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); +#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) + constexpr uint32_t BLOCK_THREADS = 1024; +#else + constexpr uint32_t BLOCK_THREADS = 512; +#endif const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -1174,7 +1182,7 @@ template cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr, uint32_t batch_size, float top_p_val, uint32_t d, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); + constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1194,7 +1202,7 @@ template cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); + constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1214,7 +1222,7 @@ template cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); + constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1360,7 +1368,7 @@ cudaError_t ParallelTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* o bool* success, IdType* row_indices, T* top_p_arr, uint32_t batch_size, uint32_t d, uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); + constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -1389,7 +1397,7 @@ cudaError_t ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids IdType* output_emitted_token_num, uint32_t batch_size, uint32_t num_speculative_tokens, uint32_t d, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = get_block_threads(); + constexpr uint32_t BLOCK_THREADS = 1024; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index a702c8ee..0c1fba76 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -22,12 +22,8 @@ #include #include 
"generated/dispatch.inc" -#ifdef FLASHINFER_ENABLE_BF16 #include -#endif -#ifdef FLASHINFER_ENABLE_FP8 #include -#endif using namespace flashinfer; From fbe214b0b3ca102b9405cc713a769fc39a97107a Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sat, 17 Aug 2024 09:01:41 +0000 Subject: [PATCH 07/12] upd --- include/flashinfer/sampling.cuh | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index b8ea6387..727e9ff5 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -686,11 +686,7 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, uint32_t max_top_k_rounds, bool deterministic, cudaStream_t stream = 0) { -#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) - constexpr uint32_t BLOCK_THREADS = 1024; -#else constexpr uint32_t BLOCK_THREADS = 512; -#endif const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -716,11 +712,7 @@ cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T* top_p_arr, uint32_t batch_size, T top_p_val, uint32_t d, uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) { -#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) constexpr uint32_t BLOCK_THREADS = 1024; -#else - constexpr uint32_t BLOCK_THREADS = 512; -#endif const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -746,11 +738,7 @@ template cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr, IdType* output, bool* success, uint32_t batch_size, float min_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { -#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) constexpr uint32_t BLOCK_THREADS = 1024; -#else - constexpr uint32_t BLOCK_THREADS = 512; -#endif const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -776,11 +764,7 @@ cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k IdType* output, bool* success, uint32_t batch_size, IdType top_k_val, T top_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { -#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800) - constexpr uint32_t BLOCK_THREADS = 1024; -#else constexpr uint32_t BLOCK_THREADS = 512; -#endif const uint32_t vec_size = std::gcd(16 / sizeof(T), d); const uint32_t smem_size = sizeof(SamplingTempStorage); @@ -1202,7 +1186,7 @@ template cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; + constexpr uint32_t BLOCK_THREADS = 512; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); @@ -1222,7 +1206,7 @@ template cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 1024; + constexpr uint32_t BLOCK_THREADS = 512; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); const uint32_t smem_size = sizeof(RenormTempStorage); From f31d8af1e373efafba294d78ee555bb5d8bbfdc4 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 26 Aug 2024 22:09:54 
+0000 Subject: [PATCH 08/12] upd --- include/flashinfer/attention/decode.cuh | 3 +- include/flashinfer/attention/handler.cuh | 20 ++-- include/flashinfer/attention/prefill.cuh | 18 ++-- include/flashinfer/sampling.cuh | 131 +++++++++++++---------- include/flashinfer/utils.cuh | 10 ++ python/csrc/pytorch_extension_utils.h | 4 +- 6 files changed, 109 insertions(+), 77 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index ab62d000..c8a7c75d 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -18,11 +18,10 @@ #include #include #include - -#include #include #include +#include #include #include #include diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 48c8126a..3ff981d3 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -556,17 +556,19 @@ cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, uint32_t& split_max_batch_siz if (avg_packed_qo_len > 64 && head_dim < 256) { warp_layout = WarpLayout::k4x1x2; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 2) } else { -#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) - if (avg_packed_qo_len > 16) { - warp_layout = WarpLayout::k4x1x1; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1) + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first >= 8) { + // Ampere or newer + if (avg_packed_qo_len > 16) { + warp_layout = WarpLayout::k4x1x1; // (num_warps_x = 4, num_warps_z = 1, num_frags_x = 1) + } else { + // avg_packed_qo_len <= 16 + warp_layout = WarpLayout::k1x4x1; // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1) + } } else { - // avg_packed_qo_len <= 16 - warp_layout = WarpLayout::k1x4x1; // (num_warps_x = 1, num_warps_z = 4, num_frags_x = 1) + // NOTE(Zihao): not enough shared memory on Turing for 1x4x1 layout + warp_layout = WarpLayout::k4x1x1; } -#else - // NOTE(Zihao): not enough shared memory for 1x4x1 layout - warp_layout = WarpLayout::k4x1x1; -#endif } const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout); diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index f5d42c9a..b7c18ef0 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -1934,16 +1934,18 @@ cudaError_t SinglePrefillWithKVCacheDispatched( if (unpacked_qo_len > 64 && HEAD_DIM < 256) { warp_layout = WarpLayout::k4x1x2; } else { -#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) - if (unpacked_qo_len > 16) { - warp_layout = WarpLayout::k4x1x1; + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first >= 8) { + // Ampere or newer + if (unpacked_qo_len > 16) { + warp_layout = WarpLayout::k4x1x1; + } else { + warp_layout = WarpLayout::k1x4x1; + } } else { - warp_layout = WarpLayout::k1x4x1; + // NOTE(Zihao): not enough shared memory on Turing for 1x4x1 layout + warp_layout = WarpLayout::k4x1x1; } -#else - // NOTE(Zihao): not enough shared memory for 1x4x1 layout - warp_layout = WarpLayout::k4x1x1; -#endif } DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 727e9ff5..4df2a006 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -40,6 +40,15 @@ using namespace cub; __VA_ARGS__ \ } +#define DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, ...) 
\ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t BLOCK_THREADS = 1024; \ + __VA_ARGS__ \ + } else { \ + constexpr uint32_t BLOCK_THREADS = 512; \ + __VA_ARGS__ \ + } + constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS; constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS; @@ -686,25 +695,28 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b T* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, uint32_t max_top_k_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 512; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &output, &success, - &top_k_arr, &top_k_val, &d, &max_top_k_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKSamplingFromProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - })}); - return cudaSuccess; + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = + sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &uniform_samples, &output, &success, + &top_k_arr, &top_k_val, &d, &max_top_k_rounds}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = TopKSamplingFromProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + })}); + return cudaSuccess; + }); } template @@ -764,25 +776,28 @@ cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples, IdType* top_k IdType* output, bool* success, uint32_t batch_size, IdType top_k_val, T top_p_val, uint32_t d, uint32_t max_rounds, bool deterministic, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 512; const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - const uint32_t smem_size = sizeof(SamplingTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &uniform_samples, &top_k_arr, &top_p_arr, &output, - &success, &top_k_val, &top_p_val, &d, &max_rounds}; - - DISPATCH_ALIGNED_VEC_SIZE( - vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { - auto kernel = TopKTopPSamplingFromProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - })}); - return cudaSuccess; + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = + sizeof(SamplingTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &uniform_samples, &top_k_arr, &top_p_arr, &output, + &success, &top_k_val, &top_p_val, &d, &max_rounds}; + + DISPATCH_ALIGNED_VEC_SIZE( + vec_size, VEC_SIZE, {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, { + auto kernel = 
TopKTopPSamplingFromProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + })}); + return cudaSuccess; + }); } template @@ -1186,40 +1201,44 @@ template cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 512; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - const uint32_t smem_size = sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKRenormProbKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = TopKRenormProbKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; }); - return cudaSuccess; } template cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits, IdType* top_k_arr, uint32_t batch_size, uint32_t top_k_val, uint32_t d, cudaStream_t stream = 0) { - constexpr uint32_t BLOCK_THREADS = 512; const uint32_t vec_size = std::gcd(16 / sizeof(DType), d); - const uint32_t smem_size = sizeof(RenormTempStorage); - dim3 nblks(batch_size); - dim3 nthrs(BLOCK_THREADS); - void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d}; - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = TopKMaskLogitsKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_NUM_THREADS(compute_capacity, BLOCK_THREADS, { + const uint32_t smem_size = sizeof(RenormTempStorage); + dim3 nblks(batch_size); + dim3 nthrs(BLOCK_THREADS); + void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d}; + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = TopKMaskLogitsKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + }); + return cudaSuccess; }); - return cudaSuccess; } template #include #include @@ -235,6 +236,15 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { return (x + y - 1) / y; } +inline std::pair GetCudaComputeCapability() { + int device_id = 0; + cudaGetDevice(&device_id); + int major = 0, minor = 0; + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_id); + cudaDeviceGetAttribute(&minor, 
cudaDevAttrComputeCapabilityMinor, device_id); + return std::make_pair(major, minor); +} + template inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { std::vector host_array(size); diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 0c1fba76..d6895041 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -15,15 +15,15 @@ */ #pragma once #include +#include #include +#include #include #include #include #include "generated/dispatch.inc" -#include -#include using namespace flashinfer; From bc9b9aba7630696ddda41e170ea133dcb90c2857 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 26 Aug 2024 22:19:23 +0000 Subject: [PATCH 09/12] upd --- .github/workflows/release_wheel.yml | 2 +- docs/installation.rst | 2 +- include/flashinfer/attention/decode.cuh | 4 ---- python/setup.py | 9 +++------ 4 files changed, 5 insertions(+), 12 deletions(-) diff --git a/.github/workflows/release_wheel.yml b/.github/workflows/release_wheel.yml index 321d268d..aa9b1265 100644 --- a/.github/workflows/release_wheel.yml +++ b/.github/workflows/release_wheel.yml @@ -18,7 +18,7 @@ on: # required: true env: - TORCH_CUDA_ARCH_LIST: "8.0 8.9 9.0+PTX" + TORCH_CUDA_ARCH_LIST: "7.5 8.0 8.9 9.0+PTX" jobs: build: diff --git a/docs/installation.rst b/docs/installation.rst index 95fbf84a..266ebbdb 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -19,7 +19,7 @@ Prerequisites - Use ``python -c "import torch; print(torch.version.cuda)"`` to check your PyTorch CUDA version. -- Supported GPU architectures: ``sm80``, ``sm86``, ``sm89``, ``sm90`` (``sm75`` / ``sm70`` support is working in progress). +- Supported GPU architectures: ``sm75``, ``sm80``, ``sm86``, ``sm89``, ``sm90``. Quick Start ^^^^^^^^^^^ diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index c8a7c75d..a84620a5 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -594,11 +594,7 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo return 512U; } } else { -#ifdef FLASHINFER_ENABLE_BF16 return 128U; -#else - return 64U; -#endif } } diff --git a/python/setup.py b/python/setup.py index 2fd605be..22d2878a 100644 --- a/python/setup.py +++ b/python/setup.py @@ -32,17 +32,14 @@ root = pathlib.Path(__name__).parent -enable_bf16 = True -# NOTE(Zihao): we haven't utilized fp8 tensor cores yet, so there is no # cuda arch check for fp8 at the moment. 
-enable_fp8 = True for cuda_arch_flags in torch_cpp_ext._get_cuda_arch_flags(): arch = int(re.search("compute_\d+", cuda_arch_flags).group()[-2:]) if arch < 75: raise RuntimeError("FlashInfer requires sm75+") - elif arch == 75: - # disable bf16 for sm75 - enable_bf16 = False + +enable_bf16 = os.environ.get("FLASHINFER_ENABLE_BF16", "1") == "1" +enable_fp8 = os.environ.get("FLASHINFER_ENABLE_FP8", "1") == "1" if enable_bf16: torch_cpp_ext.COMMON_NVCC_FLAGS.append("-DFLASHINFER_ENABLE_BF16") From 1daa15bcbb4e982a705a6434308d353f26732f1f Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 26 Aug 2024 22:49:22 +0000 Subject: [PATCH 10/12] upd --- include/flashinfer/mma.cuh | 179 +++++++++++++++++++------------------ 1 file changed, 93 insertions(+), 86 deletions(-) diff --git a/include/flashinfer/mma.cuh b/include/flashinfer/mma.cuh index 3f6ae16d..3c54a3f1 100644 --- a/include/flashinfer/mma.cuh +++ b/include/flashinfer/mma.cuh @@ -400,72 +400,76 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u } } #elif defined(FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED) - if constexpr (mma_mode == MMAMode::kInit) { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) - : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) - : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + if constexpr (std::is_same::value) { + if constexpr (mma_mode == MMAMode::kInit) { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + } else { + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, 
%10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + } } else { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) - : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3])); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) - : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, %1, %2, %3}," - "{%4, %5}," - "{%6}," - "{%7, %8, %9, %10};\n" - : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) - : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7])); + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); } #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); @@ -545,27 +549,30 @@ __device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) { "r"(1065369472), "f"(d[0]), "f"(d[1])); } #elif defined(FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED) - static_assert(std::is_same::value, "bf16 mma instruction is not supported on sm_75"); - asm volatile( - "{\n" - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, _, %1, _}," - "{%2, %3}," - "{%4}," - "{%5, 0., %6, 0.};\n" - "}\n" - : "=f"(d[0]), "=f"(d[1]) - : "r"(s_u32[0]), "r"(s_u32[1]), "r"(1006648320), "f"(d[0]), "f"(d[1])); - asm volatile( - "{\n" - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " - "{%0, _, %1, _}," - "{%2, %3}," - "{%4}," - "{%5, 0., %6, 0.};\n" - "}\n" - : "=f"(d[0]), "=f"(d[1]) - : "r"(s_u32[2]), "r"(s_u32[3]), "r"(1006648320), "f"(d[0]), "f"(d[1])); + if constexpr (std::is_same::value) { + asm volatile( + "{\n" + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, _, %1, _}," + "{%2, %3}," + "{%4}," + "{%5, 0., %6, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), "r"(s_u32[1]), "r"(1006648320), "f"(d[0]), "f"(d[1])); + asm volatile( + "{\n" + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, _, %1, _}," + "{%2, %3}," + "{%4}," + "{%5, 0., %6, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[2]), "r"(s_u32[3]), 
"r"(1006648320), "f"(d[0]), "f"(d[1])); + } else { + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); + } #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); #endif From d6b330f522631fc2d4a34a637eacaa841fa88667 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 26 Aug 2024 23:57:03 +0000 Subject: [PATCH 11/12] upd --- include/flashinfer/attention/decode.cuh | 256 ++++++++++++----------- include/flashinfer/attention/handler.cuh | 86 ++++---- include/flashinfer/utils.cuh | 9 + 3 files changed, 186 insertions(+), 165 deletions(-) diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index a84620a5..ec9349e2 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -515,6 +515,13 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( cur_page_indptr_begin + q, kv_head_idx, r, 0, last_indptr); } } +#pragma unroll + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + k_ptrs[j] = k_ptrs_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) * + tile_size_per_bdx + + j] + + tx * vec_size; + } // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); @@ -527,13 +534,6 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( alibi_slope, s, st, logits_soft_cap); block.sync(); -#pragma unroll - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - k_ptrs[j] = k_ptrs_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) * - tile_size_per_bdx + - j] + - tx * vec_size; - } // load k tiles #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { @@ -632,8 +632,8 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); - constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; + auto compute_capacity = GetCudaComputeCapability(); static_assert(bdx <= 32U); DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { constexpr uint32_t bdy = GROUP_SIZE; @@ -642,69 +642,74 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, constexpr uint32_t bdz = num_threads / (bdx * bdy); tensor_info_t info(1, seq_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 8U) : 1U; - const uint32_t smem_size = - 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + - 2U * bdy * bdz * sizeof(float); - auto kernel = SingleDecodeWithKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - if (seq_len <= 256 || tmp == nullptr) { - // no need to use partition-kv kernel - dim3 nblks = dim3(1, num_kv_heads); - dim3 nthrs = dim3(bdx, bdy, bdz); - float* lse = nullptr; - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&o, - (void*)&lse, - (void*)&info, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta, - (void*)&seq_len}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - } else { - // use partition-kv kernel - int num_blocks_per_sm = 0; - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, - num_threads, smem_size)); - uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm); - uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; - uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256); - uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size); - dim3 nblks = dim3(num_chunks, num_kv_heads); - if (nblks.x == 0 || nblks.y == 0) { - std::ostringstream err_msg; - err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")"; - throw std::runtime_error(err_msg.str()); - } - dim3 nthrs = dim3(bdx, bdy, bdz); - float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM); - void* args[] = {(void*)&q, - (void*)&k, - (void*)&v, - (void*)&tmp, - (void*)&tmp_lse, - (void*)&info, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta, - (void*)&kv_chunk_size}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + const uint32_t smem_size = + 2U * NUM_STAGES_SMEM * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + + 2U * bdy * bdz * sizeof(float); + auto kernel = SingleDecodeWithKVCacheKernel; FLASHINFER_CUDA_CALL( - MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); - } + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + if (seq_len <= 256 || tmp == nullptr) { + // no need to use partition-kv kernel + dim3 nblks = dim3(1, num_kv_heads); + dim3 nthrs = dim3(bdx, bdy, bdz); + float* lse = nullptr; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&o, + (void*)&lse, + (void*)&info, + (void*)&window_left, + (void*)&logits_soft_cap, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta, + (void*)&seq_len}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // use partition-kv kernel + int num_blocks_per_sm = 0; + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, kernel, 
num_threads, smem_size)); + uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm); + uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; + uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256); + uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size); + dim3 nblks = dim3(num_chunks, num_kv_heads); + if (nblks.x == 0 || nblks.y == 0) { + std::ostringstream err_msg; + err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")"; + throw std::runtime_error(err_msg.str()); + } + dim3 nthrs = dim3(bdx, bdy, bdz); + float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM); + void* args[] = {(void*)&q, + (void*)&k, + (void*)&v, + (void*)&tmp, + (void*)&tmp_lse, + (void*)&info, + (void*)&window_left, + (void*)&logits_soft_cap, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta, + (void*)&kv_chunk_size}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL( + MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); + } + }); }); return cudaSuccess; } @@ -723,7 +728,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( const uint32_t num_kv_heads = paged_kv.num_heads; constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); - constexpr uint32_t num_stages_smem = 2U; + auto compute_capacity = GetCudaComputeCapability(); constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { @@ -731,58 +736,63 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( constexpr uint32_t num_threads = std::max(128U, bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 4U) : 1U; - const uint32_t smem_size = - 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + - std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); - auto kernel = - BatchDecodeWithPagedKVCacheKernel; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - if (tmp_v == nullptr) { - // do not use partition-kv kernel - bool partition_kv = false; - dim3 nblks(padded_batch_size, num_kv_heads); - dim3 nthrs(bdx, bdy, bdz); - - void* args[] = {(void*)&q, - (void*)&q_offset, - (void*)&paged_kv, - (void*)&kv_partition_info, - (void*)&o, - (void*)&lse, - (void*)&block_valid_mask, - (void*)&partition_kv, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - } else { - // use partition-kv kernel - bool partition_kv = true; - void* args[] = {(void*)&q, - (void*)&q_offset, - (void*)&paged_kv, - (void*)&kv_partition_info, - (void*)&tmp_v, - (void*)&tmp_s, - (void*)&block_valid_mask, - (void*)&partition_kv, - (void*)&window_left, - (void*)&logits_soft_cap, - (void*)&sm_scale, - (void*)&rope_rcp_scale, - (void*)&rope_rcp_theta}; - dim3 nblks(padded_batch_size, num_kv_heads); - dim3 nthrs(bdx, bdy, bdz); - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse, - kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream)); - } + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + const uint32_t smem_size = + 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), + 2 * bdy * bdz * sizeof(float)); + auto kernel = + BatchDecodeWithPagedKVCacheKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + if (tmp_v == nullptr) { + // do not use partition-kv kernel + bool partition_kv = false; + dim3 nblks(padded_batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + + void* args[] = {(void*)&q, + (void*)&q_offset, + (void*)&paged_kv, + (void*)&kv_partition_info, + (void*)&o, + (void*)&lse, + (void*)&block_valid_mask, + (void*)&partition_kv, + (void*)&window_left, + (void*)&logits_soft_cap, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // use partition-kv kernel + bool partition_kv = true; + void* args[] = {(void*)&q, + (void*)&q_offset, + (void*)&paged_kv, + (void*)&kv_partition_info, + (void*)&tmp_v, + (void*)&tmp_s, + (void*)&block_valid_mask, + (void*)&partition_kv, + (void*)&window_left, + (void*)&logits_soft_cap, + (void*)&sm_scale, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + dim3 nblks(padded_batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, kv_partition_info.chunk_indptr, o, lse, + kv_partition_info.batch_size_before_partition, num_qo_heads, HEAD_DIM, stream)); + } + }); }); return cudaSuccess; } diff --git a/include/flashinfer/attention/handler.cuh 
diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh
index 3ff981d3..e29b99c4 100644
--- a/include/flashinfer/attention/handler.cuh
+++ b/include/flashinfer/attention/handler.cuh
@@ -145,51 +145,53 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
     uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr_h, const uint32_t num_qo_heads,
     const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) {
   constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
-  constexpr uint32_t num_stages_smem = 2U;
-  constexpr uint32_t bdx = HEAD_DIM / vec_size;
-  static_assert(bdx <= 32);
-  constexpr uint32_t bdy = GROUP_SIZE;
-  constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
-  constexpr uint32_t bdz = num_threads / (bdx * bdy);
-  constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U;
-  const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
-  const uint32_t smem_size =
-      2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
-      std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));
-
-  auto kernel =
-      BatchDecodeWithPagedKVCacheKernel;
-  int num_blocks_per_sm = 0;
-  int num_sm = 0;
-  int dev_id = 0;
-  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
-  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
-  FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
-                                                                     num_threads, smem_size));
-  max_grid_size = num_blocks_per_sm * num_sm;
-  if (batch_size * num_kv_heads >= max_grid_size) {
-    split_kv = false;
-    new_batch_size = batch_size;
-  } else {
-    // compute max_num_pages_per_batch and new_batch_size
-    std::vector num_pages(batch_size);
-    for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
-      num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx];
-    }
-    std::tie(max_num_pages_per_batch, new_batch_size) =
-        PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, num_kv_heads, num_pages,
-                                                            std::max(128 / page_size, 1U));
-    if (new_batch_size == batch_size && !enable_cuda_graph) {
-      // do not use partition-kv kernel for short sequence, when not using CUDAGraph
+  auto compute_capacity = GetCudaComputeCapability();
+  DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, {
+    constexpr uint32_t bdx = HEAD_DIM / vec_size;
+    static_assert(bdx <= 32);
+    constexpr uint32_t bdy = GROUP_SIZE;
+    constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
+    constexpr uint32_t bdz = num_threads / (bdx * bdy);
+    constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U;
+    const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
+    const uint32_t smem_size =
+        2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
+        std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));
+
+    auto kernel =
+        BatchDecodeWithPagedKVCacheKernel;
+    int num_blocks_per_sm = 0;
+    int num_sm = 0;
+    int dev_id = 0;
+    FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+    FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
+    FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
+                                                                       num_threads, smem_size));
+    max_grid_size = num_blocks_per_sm * num_sm;
+    if (batch_size * num_kv_heads >= max_grid_size) {
       split_kv = false;
+      new_batch_size = batch_size;
     } else {
-      // when using CUDAGraph, we always use partition-kv kernel
-      split_kv = true;
+      // compute max_num_pages_per_batch and new_batch_size
+      std::vector num_pages(batch_size);
+      for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+        num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx];
+      }
+      std::tie(max_num_pages_per_batch, new_batch_size) =
+          PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(
+              max_grid_size, num_kv_heads, num_pages, std::max(128 / page_size, 1U));
+      if (new_batch_size == batch_size && !enable_cuda_graph) {
+        // do not use partition-kv kernel for short sequence, when not using CUDAGraph
+        split_kv = false;
+      } else {
+        // when using CUDAGraph, we always use partition-kv kernel
+        split_kv = true;
+      }
     }
-  }
-  return cudaSuccess;
+    return cudaSuccess;
+  })
 }
 
 /*!
diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh
index 0ecd1e4c..a7a8dfd3 100644
--- a/include/flashinfer/utils.cuh
+++ b/include/flashinfer/utils.cuh
@@ -229,6 +229,15 @@
   }                                                                                          \
   }
 
+#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, ...) \
+  if (compute_capacity.first >= 8) {                                                        \
+    constexpr uint32_t NUM_STAGES_SMEM = 2;                                                 \
+    __VA_ARGS__                                                                             \
+  } else {                                                                                  \
+    constexpr uint32_t NUM_STAGES_SMEM = 1;                                                 \
+    __VA_ARGS__                                                                             \
+  }
+
 namespace flashinfer {
 
 template

From 63f9afc1c69c3a503117119b0ade75fcbf9353a5 Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Tue, 27 Aug 2024 00:05:59 +0000
Subject: [PATCH 12/12] upd

---
 include/flashinfer/attention/decode.cuh | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh
index ec9349e2..c1bf4cc7 100644
--- a/include/flashinfer/attention/decode.cuh
+++ b/include/flashinfer/attention/decode.cuh
@@ -515,13 +515,6 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
           cur_page_indptr_begin + q, kv_head_idx, r, 0, last_indptr);
       }
     }
-#pragma unroll
-    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
-      k_ptrs[j] = k_ptrs_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) *
-                                  tile_size_per_bdx +
-                              j] +
-                  tx * vec_size;
-    }
     // compute qk
     cp_async::wait_group<2 * num_stages_smem - 1>();
     block.sync();
@@ -534,6 +527,14 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
         alibi_slope, s, st, logits_soft_cap);
     block.sync();
 
+#pragma unroll
+    for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
+      k_ptrs[j] = k_ptrs_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) *
+                                  tile_size_per_bdx +
+                              j] +
+                  tx * vec_size;
+    }
+
     // load k tiles
 #pragma unroll
     for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
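Editor's note (illustrative, not part of the patch): the DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM macro added in utils.cuh is a runtime-to-compile-time bridge. It branches on the major compute capability returned by GetCudaComputeCapability() and re-declares NUM_STAGES_SMEM as a constexpr inside each branch, so the macro body can still use it in template arguments and shared-memory sizing; both branches are compiled, and only the one matching the device executes. Below is a minimal standalone sketch of the same pattern, with shortened hypothetical names.

// Editor's sketch (not from the patch): the shape of the compute-capability
// dispatch macro. Names are hypothetical; the real macro is
// DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM in include/flashinfer/utils.cuh.
#include <cstdint>
#include <cstdio>
#include <utility>

#define DISPATCH_NUM_STAGES(compute_capacity, NUM_STAGES, ...) \
  if ((compute_capacity).first >= 8) {                         \
    constexpr uint32_t NUM_STAGES = 2;                         \
    __VA_ARGS__                                                \
  } else {                                                     \
    constexpr uint32_t NUM_STAGES = 1;                         \
    __VA_ARGS__                                                \
  }

// Stand-in for a kernel launch that needs the stage count at compile time.
template <uint32_t NUM_STAGES>
void launch_decode_variant() {
  std::printf("using the variant compiled for %u pipeline stage(s)\n", NUM_STAGES);
}

int main() {
  // In FlashInfer this pair comes from GetCudaComputeCapability(); here it is
  // hard-coded to mimic an SM75 (Turing) device.
  std::pair<int, int> compute_capacity{7, 5};
  // Both branches of the macro instantiate the body; only the matching one runs.
  DISPATCH_NUM_STAGES(compute_capacity, NUM_STAGES, { launch_decode_variant<NUM_STAGES>(); });
  return 0;
}

The trade-off of this pattern is that the body is instantiated once per branch, so compile time and binary size grow with each dispatched value, but the chosen stage count remains a true compile-time constant inside the selected branch, which is what the smem_size computations in the hunks above rely on.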