From a9dd989e0ed23a6ebeec671aa41e2adb9e5d446c Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 19 Jul 2024 09:01:04 +0000 Subject: [PATCH 1/3] upd --- include/flashinfer/page.cuh | 42 +++++++++++ python/csrc/batch_decode.cu | 80 ++++++++++++++++----- python/csrc/batch_prefill.cu | 135 +++++++++++++++++++++++++++-------- python/csrc/flashinfer_ops.h | 21 ++++-- python/csrc/page.cu | 68 +++++++++++++----- 5 files changed, 273 insertions(+), 73 deletions(-) diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index 179edc00..986f66b3 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -115,6 +115,48 @@ struct paged_kv_t { last_page_len(nullptr), rope_pos_offset(nullptr) {} + /*! + * \brief Construct a paged key-value cache + * \param num_heads The number of heads + * \param page_size The size of each page + * \param head_dim The dimension of each head + * \param batch_size The batch size + * \param layout The layout of last 3 dimensions in KV-Cache. + * \param kv_data The flattened key-value cache + * \param k_data The flattened key cache + * \param v_data The flattened value cache + * \param indices The page indices array + * \param indptr The page indptr array + * \param last_page_len The offset of the last page for each request in the batch + * \param rope_pos_offset The start position of each request in the batch. + * \note This constructor should only be used when page_storage == kIndices + */ + __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, + uint32_t batch_size, QKVLayout layout, DType* kv_data, DType* k_data, + DType* v_data, IdType* indices, IdType* indptr, + IdType* last_page_len, IdType* rope_pos_offset = nullptr) + : num_heads(num_heads), + page_size(page_size), + head_dim(head_dim), + batch_size(batch_size), + indices(indices), + indptr(indptr), + last_page_len(last_page_len), + rope_pos_offset(rope_pos_offset) { + bool kv_defined = kv_data != nullptr; + if (kv_defined) { + stride_page = 2 * num_heads * page_size * head_dim; + k_data = kv_data; + v_data = kv_data + num_heads * page_size * head_dim; + } else { + stride_page = num_heads * page_size * head_dim; + k_data = k_data; + v_data = v_data; + } + stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim; + stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim; + } + /*! * \brief Construct a paged key-value cache * \param num_heads The number of heads diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 83079017..7607ebfc 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -105,17 +105,29 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( } std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( - torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, - torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, - unsigned int pos_encoding_mode, float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta, bool return_lse) { + torch::Tensor q, std::optional paged_kv_cache, + std::optional paged_k_cache, std::optional paged_v_cache, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, + torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, float logits_soft_cap, + float sm_scale, float rope_scale, float rope_theta, bool return_lse) { CHECK_INPUT(q); - CHECK_INPUT(paged_kv_data); + bool paged_kv_defined = paged_kv_cache.has_value(); + if (paged_kv_defined) { + CHECK_INPUT(*paged_kv_cache); + } else { + CHECK_INPUT(*paged_k_cache); + CHECK_INPUT(*paged_v_cache); + } CHECK_INPUT(paged_kv_indptr); CHECK_INPUT(paged_kv_indices); CHECK_INPUT(paged_kv_last_page_len); auto device = q.device(); - CHECK_EQ(paged_kv_data.device(), device); + if (paged_kv_defined) { + CHECK_EQ(paged_kv_cache->device(), device); + } else { + CHECK_EQ(paged_k_cache->device(), device); + CHECK_EQ(paged_v_cache->device(), device); + } CHECK_EQ(paged_kv_indices.device(), device); CHECK_EQ(paged_kv_indptr.device(), device); CHECK_EQ(paged_kv_last_page_len.device(), device); @@ -123,22 +135,41 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( CHECK_DIM(1, paged_kv_last_page_len); // (B,) CHECK_DIM(1, paged_kv_indptr); // (B+1,) CHECK_DIM(1, paged_kv_indices); // (nnz,) - // (num_max_pages, 2, H_kv, page_size, head_dim) for HND - // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD - CHECK_DIM(5, paged_kv_data); + if (paged_kv_defined) { + // (num_max_pages, 2, H_kv, page_size, head_dim) for HND + // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD + CHECK_DIM(5, *paged_kv_cache); + } else { + // (num_max_pages, H_kv, page_size, head_dim) for HND + // (num_max_pages, page_size, H_kv, head_dim) for NHD + CHECK_DIM(4, *paged_k_cache); + CHECK_DIM(4, *paged_v_cache); + } int64_t batch_size = q.size(0); int64_t num_qo_heads = q.size(1); int64_t head_dim = q.size(2); int64_t num_kv_heads, page_size; - if (kv_layout_ == QKVLayout::kHND) { - num_kv_heads = paged_kv_data.size(2); - page_size = paged_kv_data.size(3); + if (paged_kv_defined) { + CHECK_EQ(paged_kv_cache->size(1), 2); + CHECK_EQ(paged_kv_cache->size(4), head_dim); + if (kv_layout_ == QKVLayout::kHND) { + num_kv_heads = paged_kv_cache->size(2); + page_size = paged_kv_cache->size(3); + } else { + page_size = paged_kv_cache->size(2); + num_kv_heads = paged_kv_cache->size(3); + } } else { - page_size = paged_kv_data.size(2); - num_kv_heads = paged_kv_data.size(3); + CHECK_EQ(paged_k_cache->size(3), head_dim); + CHECK_EQ(paged_v_cache->size(3), head_dim); + if (kv_layout_ == QKVLayout::kHND) { + num_kv_heads = paged_k_cache->size(1); + page_size = paged_k_cache->size(2); + } else { + page_size = paged_k_cache->size(1); + num_kv_heads = paged_k_cache->size(2); + } } - CHECK_EQ(paged_kv_data.size(1), 2); - CHECK_EQ(paged_kv_data.size(4), head_dim); CHECK_GE(paged_kv_indptr.size(0), batch_size + 1); CHECK_GE(paged_kv_last_page_len.size(0), batch_size); // TODO(Zihao): support dispatching to different data types @@ -159,7 +190,8 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone; auto q_scalar_type = q.scalar_type(); - auto kv_scalar_type = paged_kv_data.scalar_type(); + auto kv_scalar_type = + paged_kv_defined ? paged_kv_cache.scalar_type() : paged_k_cache.scalar_type(); if (q_scalar_type == kv_scalar_type) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] { @@ -169,7 +201,12 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() + : nullptr), + static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() + : nullptr), + static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() + : nullptr), static_cast(paged_kv_indices.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), static_cast(paged_kv_last_page_len.data_ptr())); @@ -197,7 +234,12 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() + : nullptr), + static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() + : nullptr), + static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() + : nullptr), static_cast(paged_kv_indices.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), static_cast(paged_kv_last_page_len.data_ptr())); diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 54088682..7d6a72a7 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -58,50 +58,83 @@ void BatchPrefillWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize( } std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( - torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, + torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, + std::optional paged_k_cache, std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + bool paged_kv_defined = paged_kv_cache.has_value(); CHECK_INPUT(q); CHECK_INPUT(qo_indptr); - CHECK_INPUT(paged_kv_data); + if (paged_kv_defined) { + CHECK_INPUT(*paged_kv_cache); + } else { + CHECK_INPUT(*paged_k_cache); + CHECK_INPUT(*paged_v_cache); + } CHECK_INPUT(paged_kv_indptr); CHECK_INPUT(paged_kv_indices); CHECK_INPUT(paged_kv_last_page_len); auto device = q.device(); CHECK_EQ(device, qo_indptr.device()); - CHECK_EQ(device, paged_kv_data.device()); + if (paged_kv_defined) { + CHECK_EQ(device, paged_kv_cache->device()); + } else { + CHECK_EQ(device, paged_k_cache->device()); + CHECK_EQ(device, paged_v_cache->device()); + } CHECK_EQ(device, paged_kv_indptr.device()); CHECK_EQ(device, paged_kv_indices.device()); CHECK_EQ(device, paged_kv_last_page_len.device()); CHECK_DIM(3, q); // (nnz_qo, H_qo, D) CHECK_DIM(1, qo_indptr); // (B + 1,) - // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND - // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND - CHECK_DIM(5, paged_kv_data); + + if (paged_kv_defined) { + // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND + // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND + CHECK_DIM(5, *paged_kv_cache); + } else { + // [max_num_pages, num_kv_heads, page_size, head_dim] for HND + // [max_num_pages, page_size, num_kv_heads, head_dim] for HND + CHECK_DIM(4, *paged_k_cache); + CHECK_DIM(4, *paged_v_cache); + } + CHECK_DIM(1, paged_kv_indptr); // (B + 1,) CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) CHECK_DIM(1, paged_kv_last_page_len); // (B,) - CHECK_EQ(q.scalar_type(), paged_kv_data.scalar_type()); + CHECK_EQ(q.scalar_type(), paged_kv_cache.scalar_type()); int64_t batch_size = qo_indptr.size(0) - 1; int64_t nnz_qo = q.size(0); int64_t num_qo_heads = q.size(1); int64_t head_dim = q.size(2); int64_t num_kv_heads, page_size; - if (kv_layout_ == QKVLayout::kHND) { - num_kv_heads = paged_kv_data.size(2); - page_size = paged_kv_data.size(3); + + if (paged_kv_defined) { + CHECK_EQ(paged_kv_cache->size(1), 2); + CHECK_EQ(paged_kv_cache->size(4), head_dim); + if (kv_layout_ == QKVLayout::kHND) { + num_kv_heads = paged_kv_cache->size(2); + page_size = paged_kv_cache->size(3); + } else { + page_size = paged_kv_cache->size(2); + num_kv_heads = paged_kv_cache->size(3); + } } else { - page_size = paged_kv_data.size(2); - num_kv_heads = paged_kv_data.size(3); + CHECK_EQ(paged_kv_cache->size(3), head_dim); + if (kv_layout_ == QKVLayout::kHND) { + num_kv_heads = paged_kv_cache->size(1); + page_size = paged_kv_cache->size(2); + } else { + page_size = paged_kv_cache->size(1); + num_kv_heads = paged_kv_cache->size(2); + } } CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); CHECK_GE(qo_indptr.size(0), batch_size + 1); CHECK_GE(paged_kv_indptr.size(0), batch_size + 1); CHECK_GE(paged_kv_last_page_len.size(0), batch_size); - CHECK_EQ(paged_kv_data.size(1), 2); - CHECK_EQ(paged_kv_data.size(4), head_dim); qo_indptr = qo_indptr.to(torch::kInt32); paged_kv_indptr = paged_kv_indptr.to(torch::kInt32); paged_kv_indices = paged_kv_indices.to(torch::kInt32); @@ -122,7 +155,9 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), + static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), + static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr), static_cast(paged_kv_indices.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), static_cast(paged_kv_last_page_len.data_ptr())); @@ -162,14 +197,21 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( } std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCustomMask( - torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, + torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, + std::optional paged_k_cache, std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, torch::Tensor custom_mask, torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse) { + bool paged_kv_defined = paged_kv_cache.has_value(); CHECK_INPUT(q); CHECK_INPUT(qo_indptr); - CHECK_INPUT(paged_kv_data); + if (paged_kv_defined) { + CHECK_INPUT(*paged_kv_cache); + } else { + CHECK_INPUT(*paged_k_cache); + CHECK_INPUT(*paged_v_cache); + } CHECK_INPUT(paged_kv_indptr); CHECK_INPUT(paged_kv_indices); CHECK_INPUT(paged_kv_last_page_len); @@ -177,7 +219,12 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu CHECK_INPUT(qk_indptr); auto device = q.device(); CHECK_EQ(device, qo_indptr.device()); - CHECK_EQ(device, paged_kv_data.device()); + if (paged_kv_defined) { + CHECK_EQ(device, paged_kv_cache->device()); + } else { + CHECK_EQ(device, paged_k_cache->device()); + CHECK_EQ(device, paged_v_cache->device()); + } CHECK_EQ(device, paged_kv_indptr.device()); CHECK_EQ(device, paged_kv_indices.device()); CHECK_EQ(device, paged_kv_last_page_len.device()); @@ -185,33 +232,59 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu CHECK_EQ(device, qk_indptr.device()); CHECK_DIM(3, q); // (nnz_qo, H_qo, D) CHECK_DIM(1, qo_indptr); // (B + 1,) - // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND - // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND - CHECK_DIM(5, paged_kv_data); + + if (paged_kv_defined) { + // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND + // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND + CHECK_DIM(5, *paged_kv_cache); + } else { + // [max_num_pages, num_kv_heads, page_size, head_dim] for HND + // [max_num_pages, page_size, num_kv_heads, head_dim] for HND + CHECK_DIM(4, *paged_k_cache); + CHECK_DIM(4, *paged_v_cache); + } CHECK_DIM(1, paged_kv_indptr); // (B + 1,) CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) CHECK_DIM(1, paged_kv_last_page_len); // (B,) CHECK_DIM(1, custom_mask); // (nnz_qk,) CHECK_DIM(1, qk_indptr); // (B + 1,) - CHECK_EQ(q.scalar_type(), paged_kv_data.scalar_type()); + if (paged_kv_defined) { + CHECK_EQ(q.scalar_type(), paged_kv_cache->scalar_type()); + } else { + CHECK_EQ(q.scalar_type(), paged_k_cache->scalar_type()); + CHECK_EQ(q.scalar_type(), paged_v_cache->scalar_type()); + } int64_t batch_size = qo_indptr.size(0) - 1; int64_t nnz_qo = q.size(0); int64_t num_qo_heads = q.size(1); int64_t head_dim = q.size(2); int64_t num_kv_heads, page_size; - if (kv_layout_ == QKVLayout::kHND) { - num_kv_heads = paged_kv_data.size(2); - page_size = paged_kv_data.size(3); + + if (paged_kv_defined) { + CHECK_EQ(paged_kv_cache->size(1), 2); + CHECK_EQ(paged_kv_cache->size(4), head_dim); + if (kv_layout_ == QKVLayout::kHND) { + num_kv_heads = paged_kv_cache->size(2); + page_size = paged_kv_cache->size(3); + } else { + page_size = paged_kv_cache->size(2); + num_kv_heads = paged_kv_cache->size(3); + } } else { - page_size = paged_kv_data.size(2); - num_kv_heads = paged_kv_data.size(3); + CHECK_EQ(paged_k_cache->size(3), head_dim); + CHECK_EQ(paged_v_cache->size(3), head_dim); + if (kv_layout_ == QKVLayout::kHND) { + num_kv_heads = paged_k_cache->size(1); + page_size = paged_k_cache->size(2); + } else { + page_size = paged_k_cache->size(1); + num_kv_heads = paged_k_cache->size(2); + } } CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); CHECK_GE(qo_indptr.size(0), batch_size + 1); CHECK_GE(paged_kv_indptr.size(0), batch_size + 1); CHECK_GE(paged_kv_last_page_len.size(0), batch_size); - CHECK_EQ(paged_kv_data.size(1), 2); - CHECK_EQ(paged_kv_data.size(4), head_dim); CHECK_GE(qk_indptr.size(0), batch_size + 1); qo_indptr = qo_indptr.to(torch::kInt32); paged_kv_indptr = paged_kv_indptr.to(torch::kInt32); @@ -234,7 +307,9 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { paged_kv_t paged_kv( num_kv_heads, page_size, head_dim, batch_size, kv_layout_, - static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), + static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), + static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr), static_cast(paged_kv_indices.data_ptr()), static_cast(paged_kv_indptr.data_ptr()), static_cast(paged_kv_last_page_len.data_ptr())); diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 18559208..d837528f 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -38,9 +38,11 @@ std::vector single_prefill_with_kv_cache_custom_mask( float rope_theta, bool return_lse); void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor append_indptr, torch::Tensor kv_data, - torch::Tensor kv_indices, torch::Tensor kv_indptr, - torch::Tensor kv_last_page_len, unsigned int layout); + torch::Tensor append_indptr, std::optional paged_kv_cache, + std::optional paged_k_cache, + std::optional paged_v_cache, torch::Tensor kv_indices, + torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, + unsigned int layout); std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b, torch::Tensor s_b); @@ -88,7 +90,9 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper { void EndForward(); void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } - std::vector Forward(torch::Tensor q, torch::Tensor paged_kv_data, + std::vector Forward(torch::Tensor q, std::optional paged_kv_cache, + std::optional paged_k_cache, + std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, float logits_soft_cap, @@ -118,14 +122,17 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, - torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, - torch::Tensor paged_kv_indices, + std::optional paged_kv_cache, + std::optional paged_k_cache, + std::optional paged_v_cache, + torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, bool causal, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta, bool return_lse); std::vector ForwardCustomMask( - torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor paged_kv_data, + torch::Tensor q, torch::Tensor qo_indptr, std::optional paged_kv_cache, + std::optional paged_k_cache, std::optional paged_v_cache, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, torch::Tensor packed_custom_mask, torch::Tensor qk_indptr, unsigned int pos_encoding_mode, bool allow_fp16_qk_reduction, diff --git a/python/csrc/page.cu b/python/csrc/page.cu index 6f593fc0..0f494390 100644 --- a/python/csrc/page.cu +++ b/python/csrc/page.cu @@ -21,20 +21,33 @@ using namespace flashinfer; void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor append_indptr, torch::Tensor kv_data, - torch::Tensor kv_indices, torch::Tensor kv_indptr, - torch::Tensor kv_last_page_len, unsigned int layout) { + torch::Tensor append_indptr, std::optional paged_kv_cache, + std::optional paged_k_cache, + std::optional paged_v_cache, torch::Tensor kv_indices, + torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, + unsigned int layout) { + bool paged_kv_defined = paged_kv_cache.has_value(); CHECK_INPUT(append_key); CHECK_INPUT(append_value); CHECK_INPUT(append_indptr); - CHECK_INPUT(kv_data); + if (paged_kv_defined) { + CHECK_INPUT(*paged_kv_cache); + } else { + CHECK_INPUT(*paged_k_cache); + CHECK_INPUT(*paged_v_cache); + } CHECK_INPUT(kv_indices); CHECK_INPUT(kv_indptr); CHECK_INPUT(kv_last_page_len); CHECK_DIM(3, append_key); CHECK_DIM(3, append_value); CHECK_DIM(1, append_indptr); - CHECK_DIM(5, kv_data); + if (paged_kv_defined) { + CHECK_DIM(5, *paged_kv_cache); + } else { + CHECK_DIM(4, *paged_k_cache); + CHECK_DIM(4, *paged_v_cache); + } CHECK_DIM(1, kv_indices); CHECK_DIM(1, kv_indptr); CHECK_DIM(1, kv_last_page_len); @@ -48,7 +61,12 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, auto device = append_indptr.device(); CHECK_EQ(append_key.device(), device); CHECK_EQ(append_value.device(), device); - CHECK_EQ(kv_data.device(), device); + if (paged_kv_defined) { + CHECK_EQ(paged_kv_cache->device(), device); + } else { + CHECK_EQ(paged_k_cache->device(), device); + CHECK_EQ(paged_v_cache->device(), device); + } CHECK_EQ(kv_indices.device(), device); CHECK_EQ(kv_indptr.device(), device); CHECK_EQ(kv_last_page_len.device(), device); @@ -57,14 +75,24 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, QKVLayout kv_layout = QKVLayout(layout); unsigned int num_heads, page_size, head_dim; - if (kv_layout == QKVLayout::kHND) { - num_heads = kv_data.size(2); - page_size = kv_data.size(3); - head_dim = kv_data.size(4); + if (paged_kv_defined) { + head_dim = paged_kv_cache->size(4); + if (kv_layout == QKVLayout::kHND) { + num_heads = paged_kv_cache->size(2); + page_size = paged_kv_cache->size(3); + } else { + page_size = paged_kv_cache->size(2); + num_heads = paged_kv_cache->size(3); + } } else { - page_size = kv_data.size(2); - num_heads = kv_data.size(3); - head_dim = kv_data.size(4); + head_dim = paged_k_cache->size(3); + if (kv_layout == QKVLayout::kHND) { + num_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); + } else { + page_size = paged_k_cache.size(1); + num_heads = paged_k_cache.size(2); + } } CHECK_EQ(append_key.size(1), num_heads); @@ -74,11 +102,16 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); - bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_data.scalar_type(), c_type, [&] { + auto kv_scalar_dtype = + paged_kv_cache.has_value() ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); + + bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE(kv_scalar_dtype, c_type, [&] { paged_kv_t paged_kv( num_heads, page_size, head_dim, batch_size, kv_layout, - static_cast(kv_data.data_ptr()), static_cast(kv_indices.data_ptr()), - static_cast(kv_indptr.data_ptr()), + static_cast(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr() : nullptr), + static_cast(paged_k_cache.has_value() ? paged_k_cache->data_ptr() : nullptr), + static_cast(paged_v_cache.has_value() ? paged_v_cache->data_ptr() : nullptr), + static_cast(kv_indices.data_ptr()), static_cast(kv_indptr.data_ptr()), static_cast(kv_last_page_len.data_ptr())); cudaError_t status = AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), @@ -89,5 +122,6 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, return true; }); - TORCH_CHECK(success, "AppendPagedKVCache failed to dispatch with dtype ", kv_data.scalar_type()); + TORCH_CHECK(success, "AppendPagedKVCache failed to dispatch with dtype ", + paged_kv_cache.scalar_type()); } From d1460fc68ac0c6c9362d8ab601319c05b34afd3f Mon Sep 17 00:00:00 2001 From: yzh119 Date: Fri, 19 Jul 2024 21:05:08 +0000 Subject: [PATCH 2/3] upd --- python/csrc/batch_decode.cu | 16 ++++----- python/csrc/batch_prefill.cu | 28 ++++++++------- python/csrc/page.cu | 22 ++++++------ python/flashinfer/cascade.py | 7 ---- python/flashinfer/decode.py | 45 +++++++++++------------- python/flashinfer/page.py | 14 ++++---- python/flashinfer/prefill.py | 68 ++++++++++++++---------------------- python/flashinfer/sparse.py | 29 ++++++--------- python/flashinfer/utils.py | 36 +++++++++++++++++-- 9 files changed, 131 insertions(+), 134 deletions(-) diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index 7607ebfc..b23ae49b 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -113,10 +113,10 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( CHECK_INPUT(q); bool paged_kv_defined = paged_kv_cache.has_value(); if (paged_kv_defined) { - CHECK_INPUT(*paged_kv_cache); + CHECK_INPUT(paged_kv_cache.value()); } else { - CHECK_INPUT(*paged_k_cache); - CHECK_INPUT(*paged_v_cache); + CHECK_INPUT(paged_k_cache.value()); + CHECK_INPUT(paged_v_cache.value()); } CHECK_INPUT(paged_kv_indptr); CHECK_INPUT(paged_kv_indices); @@ -138,12 +138,12 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( if (paged_kv_defined) { // (num_max_pages, 2, H_kv, page_size, head_dim) for HND // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD - CHECK_DIM(5, *paged_kv_cache); + CHECK_DIM(5, paged_kv_cache.value()); } else { // (num_max_pages, H_kv, page_size, head_dim) for HND // (num_max_pages, page_size, H_kv, head_dim) for NHD - CHECK_DIM(4, *paged_k_cache); - CHECK_DIM(4, *paged_v_cache); + CHECK_DIM(4, paged_k_cache.value()); + CHECK_DIM(4, paged_v_cache.value()); } int64_t batch_size = q.size(0); int64_t num_qo_heads = q.size(1); @@ -191,8 +191,8 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( auto q_scalar_type = q.scalar_type(); auto kv_scalar_type = - paged_kv_defined ? paged_kv_cache.scalar_type() : paged_k_cache.scalar_type(); - + paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); + if (q_scalar_type == kv_scalar_type) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index 7d6a72a7..d07c2aa2 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -68,10 +68,10 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( CHECK_INPUT(q); CHECK_INPUT(qo_indptr); if (paged_kv_defined) { - CHECK_INPUT(*paged_kv_cache); + CHECK_INPUT(paged_kv_cache.value()); } else { - CHECK_INPUT(*paged_k_cache); - CHECK_INPUT(*paged_v_cache); + CHECK_INPUT(paged_k_cache.value()); + CHECK_INPUT(paged_v_cache.value()); } CHECK_INPUT(paged_kv_indptr); CHECK_INPUT(paged_kv_indices); @@ -93,18 +93,20 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( if (paged_kv_defined) { // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND - CHECK_DIM(5, *paged_kv_cache); + CHECK_DIM(5, paged_kv_cache.value()); + CHECK_EQ(q.scalar_type(), paged_kv_cache->scalar_type()); } else { // [max_num_pages, num_kv_heads, page_size, head_dim] for HND // [max_num_pages, page_size, num_kv_heads, head_dim] for HND - CHECK_DIM(4, *paged_k_cache); - CHECK_DIM(4, *paged_v_cache); + CHECK_DIM(4, paged_k_cache.value()); + CHECK_DIM(4, paged_v_cache.value()); + CHECK_EQ(q.scalar_type(), paged_k_cache->scalar_type()); + CHECK_EQ(q.scalar_type(), paged_v_cache->scalar_type()); } CHECK_DIM(1, paged_kv_indptr); // (B + 1,) CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) CHECK_DIM(1, paged_kv_last_page_len); // (B,) - CHECK_EQ(q.scalar_type(), paged_kv_cache.scalar_type()); int64_t batch_size = qo_indptr.size(0) - 1; int64_t nnz_qo = q.size(0); int64_t num_qo_heads = q.size(1); @@ -207,10 +209,10 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu CHECK_INPUT(q); CHECK_INPUT(qo_indptr); if (paged_kv_defined) { - CHECK_INPUT(*paged_kv_cache); + CHECK_INPUT(paged_kv_cache.value()); } else { - CHECK_INPUT(*paged_k_cache); - CHECK_INPUT(*paged_v_cache); + CHECK_INPUT(paged_k_cache.value()); + CHECK_INPUT(paged_v_cache.value()); } CHECK_INPUT(paged_kv_indptr); CHECK_INPUT(paged_kv_indices); @@ -236,12 +238,12 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::ForwardCu if (paged_kv_defined) { // [max_num_pages, 2, num_kv_heads, page_size, head_dim] for HND // [max_num_pages, 2, page_size, num_kv_heads, head_dim] for HND - CHECK_DIM(5, *paged_kv_cache); + CHECK_DIM(5, paged_kv_cache.value()); } else { // [max_num_pages, num_kv_heads, page_size, head_dim] for HND // [max_num_pages, page_size, num_kv_heads, head_dim] for HND - CHECK_DIM(4, *paged_k_cache); - CHECK_DIM(4, *paged_v_cache); + CHECK_DIM(4, paged_k_cache.value()); + CHECK_DIM(4, paged_v_cache.value()); } CHECK_DIM(1, paged_kv_indptr); // (B + 1,) CHECK_DIM(1, paged_kv_indices); // (nnz_kv,) diff --git a/python/csrc/page.cu b/python/csrc/page.cu index 0f494390..586c1c98 100644 --- a/python/csrc/page.cu +++ b/python/csrc/page.cu @@ -31,10 +31,10 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, CHECK_INPUT(append_value); CHECK_INPUT(append_indptr); if (paged_kv_defined) { - CHECK_INPUT(*paged_kv_cache); + CHECK_INPUT(paged_kv_cache.value()); } else { - CHECK_INPUT(*paged_k_cache); - CHECK_INPUT(*paged_v_cache); + CHECK_INPUT(paged_k_cache.value()); + CHECK_INPUT(paged_v_cache.value()); } CHECK_INPUT(kv_indices); CHECK_INPUT(kv_indptr); @@ -43,10 +43,10 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, CHECK_DIM(3, append_value); CHECK_DIM(1, append_indptr); if (paged_kv_defined) { - CHECK_DIM(5, *paged_kv_cache); + CHECK_DIM(5, paged_kv_cache.value()); } else { - CHECK_DIM(4, *paged_k_cache); - CHECK_DIM(4, *paged_v_cache); + CHECK_DIM(4, paged_k_cache.value()); + CHECK_DIM(4, paged_v_cache.value()); } CHECK_DIM(1, kv_indices); CHECK_DIM(1, kv_indptr); @@ -87,11 +87,11 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, } else { head_dim = paged_k_cache->size(3); if (kv_layout == QKVLayout::kHND) { - num_heads = paged_k_cache.size(1); - page_size = paged_k_cache.size(2); + num_heads = paged_k_cache->size(1); + page_size = paged_k_cache->size(2); } else { - page_size = paged_k_cache.size(1); - num_heads = paged_k_cache.size(2); + page_size = paged_k_cache->size(1); + num_heads = paged_k_cache->size(2); } } @@ -123,5 +123,5 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, }); TORCH_CHECK(success, "AppendPagedKVCache failed to dispatch with dtype ", - paged_kv_cache.scalar_type()); + kv_scalar_dtype); } diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 55feadba..58d27ecb 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -37,13 +37,6 @@ single_prefill_with_kv_cache_return_lse, BatchPrefillWithPagedKVCacheWrapper, ) -from .utils import ( - expand_5d, - check_pos_encoding_mode, - check_kv_layout, - PosEncodingMode, - TensorLayout, -) def merge_state( diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 2d6ee5cc..abe78c06 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -34,9 +34,9 @@ from .utils import ( PosEncodingMode, TensorLayout, - expand_5d, - check_pos_encoding_mode, - check_kv_layout, + _check_pos_encoding_mode, + _check_kv_layout, + _unpack_paged_kv_cache, ) _cache_buf = {} @@ -139,8 +139,8 @@ def single_decode_with_kv_cache( not equal to ``num_kv_heads``, the function will use `grouped query attention `_. """ - check_pos_encoding_mode(pos_encoding_mode) - check_kv_layout(kv_layout) + _check_pos_encoding_mode(pos_encoding_mode) + _check_kv_layout(kv_layout) tmp = _get_cache_buf("single_decode_with_kv_cache_tmp", 32 * 1024 * 1024, q.device) if logits_soft_cap is None: logits_soft_cap = 0.0 @@ -228,7 +228,7 @@ class BatchDecodeWithPagedKVCacheWrapper: >>> kv_last_page_len = torch.tensor( ... [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0" ... ) - >>> kv_data_at_layer = [ + >>> kv_cache_at_layer = [ ... torch.randn( ... max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" ... ) for _ in range(num_layers) @@ -248,9 +248,9 @@ class BatchDecodeWithPagedKVCacheWrapper: >>> outputs = [] >>> for i in range(num_layers): ... q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0") - ... kv_data = kv_data_at_layer[i] + ... kv_cache = kv_cache_at_layer[i] ... # compute batch decode attention, reuse auxiliary data structures for all layers - ... o = decode_wrapper.forward(q, kv_data) + ... o = decode_wrapper.forward(q, kv_cache) ... outputs.append(o) ... >>> # clear auxiliary data structures @@ -313,7 +313,7 @@ def __init__( size of the buffer should be ``[batch_size]``. Only needed when ``use_cuda_graph`` is ``True``. """ - check_kv_layout(kv_layout) + _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer @@ -476,7 +476,7 @@ def begin_forward( else q_data_type ), ) - empty_kv_data = torch.empty( + empty_kv_cache = torch.empty( 0, dtype=( getattr(torch, data_type) if isinstance(data_type, str) else data_type @@ -523,7 +523,7 @@ def begin_forward( PosEncodingMode[pos_encoding_mode].value, logits_soft_cap, empty_q_data, - empty_kv_data, + empty_kv_cache, ) def end_forward(self): @@ -538,7 +538,7 @@ def end_forward(self): def forward( self, q: torch.Tensor, - paged_kv_data: torch.Tensor, + paged_kv_cache: torch.Tensor, pos_encoding_mode: str = "NONE", q_scale: Optional[float] = None, k_scale: Optional[float] = None, @@ -554,7 +554,7 @@ def forward( ---------- q : torch.Tensor The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` - paged_kv_data : torch.Tensor + paged_kv_cache : torch.Tensor A 5-D tensor of the reserved paged kv-cache data, shape: ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, or @@ -589,7 +589,7 @@ def forward( torch.Tensor The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. """ - check_pos_encoding_mode(pos_encoding_mode) + _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: @@ -604,13 +604,11 @@ def forward( if rope_theta is None: rope_theta = 1e4 - paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) - if self.use_tensor_cores: out = self._wrapper.forward( q, self._qo_indptr_buf, - paged_kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, @@ -626,7 +624,7 @@ def forward( else: out = self._wrapper.forward( q, - paged_kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, @@ -644,7 +642,7 @@ def forward( def forward_return_lse( self, q: torch.Tensor, - paged_kv_data: torch.Tensor, + paged_kv_cache: torch.Tensor, pos_encoding_mode: str = "NONE", q_scale: Optional[float] = None, k_scale: Optional[float] = None, @@ -661,7 +659,7 @@ def forward_return_lse( ---------- q : torch.Tensor The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` - paged_kv_data : torch.Tensor + paged_kv_cache : torch.Tensor A 5-D tensor of the reserved paged kv-cache data, shape: ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, or @@ -703,7 +701,7 @@ def forward_return_lse( Please refer to the :ref:`tutorial ` for a detailed explanation of the log-sum-exp function and attention states. """ - check_pos_encoding_mode(pos_encoding_mode) + _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: @@ -717,12 +715,11 @@ def forward_return_lse( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 - paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) if self.use_tensor_cores: V, s = self._wrapper.forward( q, self._qo_indptr_buf, - paged_kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, @@ -738,7 +735,7 @@ def forward_return_lse( else: V, s = self._wrapper.forward( q, - paged_kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py index 6c7c7558..236d534d 100644 --- a/python/flashinfer/page.py +++ b/python/flashinfer/page.py @@ -28,14 +28,14 @@ else: raise e -from .utils import check_kv_layout, TensorLayout +from .utils import _check_kv_layout, TensorLayout, _unpack_paged_kv_cache def append_paged_kv_cache( append_key: torch.Tensor, append_value: torch.Tensor, append_indptr: torch.Tensor, - kv_data: torch.Tensor, + paged_kv_cache: torch.Tensor, kv_indices: torch.Tensor, kv_indptr: torch.Tensor, kv_last_page_len: torch.Tensor, @@ -53,7 +53,7 @@ def append_paged_kv_cache( ``[append_indptr[-1], num_kv_heads, head_dim]``. append_indptr : torch.Tensor The indptr tensor of the key-value pairs to append, shape: ``[batch_size + 1]``. - kv_data : torch.Tensor + paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] The 5-D tensor of the paged key-value cache, shape: ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, or @@ -85,7 +85,7 @@ def append_paged_kv_cache( ... ).int() >>> max_num_pages = 1000 >>> page_size = 16 - >>> kv_data = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim).half().to(0) + >>> paged_kv_cache = torch.randn(max_num_pages, 2, page_size, num_kv_heads, head_dim).half().to(0) >>> num_pages_per_req = torch.tensor([3, 1, 2, 2], dtype=torch.int32, device="cuda:0") >>> kv_page_indptr = torch.cat( ... [torch.zeros(1).int().to(0), torch.cumsum(num_pages_per_req, dim=0)] @@ -102,7 +102,7 @@ def append_paged_kv_cache( ... k_append, ... v_append, ... kv_append_indptr, - ... kv_data, + ... paged_kv_cache, ... kv_page_indices, ... kv_page_indptr, ... kv_last_page_len @@ -117,12 +117,12 @@ def append_paged_kv_cache( which means :attr:`kv_indices`, :attr:`kv_indptr`, :attr:`kv_last_page_len` has incorporated appended k/v. """ - check_kv_layout(kv_layout) + _check_kv_layout(kv_layout) _kernels.append_paged_kv_cache( append_key, append_value, append_indptr, - kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, kv_layout), kv_indices, kv_indptr, kv_last_page_len, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index dc7f2057..07c7282b 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -34,9 +34,9 @@ from .utils import ( PosEncodingMode, TensorLayout, - expand_5d, - check_pos_encoding_mode, - check_kv_layout, + _check_pos_encoding_mode, + _check_kv_layout, + _unpack_paged_kv_cache, is_float8, ) from .quantization import packbits, segment_packbits @@ -165,8 +165,8 @@ def single_prefill_with_kv_cache( not equal to ``num_kv_heads``, the function will use `grouped query attention `_. """ - check_pos_encoding_mode(pos_encoding_mode) - check_kv_layout(kv_layout) + _check_pos_encoding_mode(pos_encoding_mode) + _check_kv_layout(kv_layout) tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 32 * 1024 * 1024, q.device) if logits_soft_cap is None: logits_soft_cap = 0.0 @@ -334,8 +334,8 @@ def single_prefill_with_kv_cache_return_lse( not equal to ``num_kv_heads``, the function will use `grouped query attention `_. """ - check_pos_encoding_mode(pos_encoding_mode) - check_kv_layout(kv_layout) + _check_pos_encoding_mode(pos_encoding_mode) + _check_kv_layout(kv_layout) tmp = _get_cache_buf( "single_prefill_with_kv_cache_return_lse_tmp", 8 * 1024 * 1024, q.device ) @@ -452,7 +452,7 @@ class BatchPrefillWithPagedKVCacheWrapper: ... [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0" ... ) >>> q_at_layer = torch.randn(num_layers, nnz_qo, num_qo_heads, head_dim).half().to("cuda:0") - >>> kv_data_at_layer = torch.randn( + >>> kv_cache_at_layer = torch.randn( ... num_layers, max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" ... ) >>> # create auxiliary data structures for batch prefill attention @@ -469,10 +469,10 @@ class BatchPrefillWithPagedKVCacheWrapper: >>> outputs = [] >>> for i in range(num_layers): ... q = q_at_layer[i] - ... kv_data = kv_data_at_layer[i] + ... kv_cache = kv_cache_at_layer[i] ... # compute batch prefill attention, reuse auxiliary data structures ... o = prefill_wrapper.forward( - ... q, kv_data, causal=True + ... q, kv_cache, causal=True ... ) ... outputs.append(o) ... @@ -507,10 +507,10 @@ class BatchPrefillWithPagedKVCacheWrapper: >>> outputs_custom_mask = [] >>> for i in range(num_layers): ... q = q_at_layer[i] - ... kv_data = kv_data_at_layer[i] + ... kv_cache = kv_cache_at_layer[i] ... # compute batch prefill attention, reuse auxiliary data structures ... o_custom = prefill_wrapper.forward( - ... q, kv_data + ... q, kv_cache ... ) ... assert torch.allclose(o_custom, outputs[i], rtol=1e-3, atol=1e-3) ... @@ -588,7 +588,7 @@ def __init__( This argument is only effective when ``use_cuda_graph`` is ``True`` and the custom mask will be used in attention computation. """ - check_kv_layout(kv_layout) + _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper( @@ -801,7 +801,7 @@ def end_forward(self): def forward( self, q: torch.Tensor, - paged_kv_data: torch.Tensor, + paged_kv_cache: torch.Tensor, causal: bool = True, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, @@ -816,7 +816,7 @@ def forward( ---------- q : torch.Tensor The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` - paged_kv_data : torch.Tensor + paged_kv_cache : torch.Tensor A 5-D tensor of the reserved paged kv-cache data, shape: ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, or @@ -853,7 +853,7 @@ def forward( torch.Tensor The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ - check_pos_encoding_mode(pos_encoding_mode) + _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: @@ -862,20 +862,12 @@ def forward( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 - if is_float8(q): - logging.warning( - "Our current prefill kernel implementation needs f16 input, the f8 inputs " - " are casted to f16, which could result in performance degradation." - ) - q = q.to(torch.float16) - paged_kv_data = paged_kv_data.to(torch.float16) - paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) if self._custom_mask_buf is None: return self._wrapper.forward( q, self._qo_indptr_buf, - paged_kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, @@ -892,7 +884,7 @@ def forward( return self._wrapper.forward_custom_mask( q, self._qo_indptr_buf, - paged_kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, @@ -910,7 +902,7 @@ def forward( def forward_return_lse( self, q: torch.Tensor, - paged_kv_data: torch.Tensor, + paged_kv_cache: torch.Tensor, causal: bool = True, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, @@ -925,7 +917,7 @@ def forward_return_lse( ---------- q : torch.Tensor The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` - paged_kv_data : torch.Tensor + paged_kv_cache : torch.Tensor A 5-D tensor of the reserved paged kv-cache data, shape: ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, or @@ -963,7 +955,7 @@ def forward_return_lse( The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ - check_pos_encoding_mode(pos_encoding_mode) + _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: @@ -972,20 +964,12 @@ def forward_return_lse( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 - if is_float8(q): - logging.warning( - "Our current prefill kernel implementation needs f16 input, the f8 inputs " - " are casted to f16, which could result in performance degradation." - ) - q = q.to(torch.float16) - paged_kv_data = paged_kv_data.to(torch.float16) - paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) if self._custom_mask_buf is None: return self._wrapper.forward( q, self._qo_indptr_buf, - paged_kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, @@ -1002,7 +986,7 @@ def forward_return_lse( return self._wrapper.forward( q, self._qo_indptr_buf, - paged_kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len_buf, @@ -1172,7 +1156,7 @@ def __init__( This argument is only effective when ``use_cuda_graph`` is ``True`` and custom mask will be used in attention computation. """ - check_kv_layout(kv_layout) + _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer self._wrapper = _kernels.BatchPrefillWithRaggedKVCachePyTorchWrapper( @@ -1399,7 +1383,7 @@ def forward( torch.Tensor The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ - check_pos_encoding_mode(pos_encoding_mode) + _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: @@ -1506,7 +1490,7 @@ def forward_return_lse( The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ - check_pos_encoding_mode(pos_encoding_mode) + _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: diff --git a/python/flashinfer/sparse.py b/python/flashinfer/sparse.py index df92f2ab..3c18e94c 100644 --- a/python/flashinfer/sparse.py +++ b/python/flashinfer/sparse.py @@ -21,10 +21,10 @@ from .prefill import _compute_page_qk_indptr from .quantization import segment_packbits from .utils import ( - check_pos_encoding_mode, - check_kv_layout, + _check_pos_encoding_mode, + _check_kv_layout, + _unpack_paged_kv_cache, is_float8, - expand_5d, PosEncodingMode, TensorLayout, ) @@ -62,7 +62,7 @@ def __init__( kv_layout : str The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. """ - check_kv_layout(kv_layout) + _check_kv_layout(kv_layout) self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper( @@ -190,7 +190,7 @@ def end_forward(self): def forward( self, q: torch.Tensor, - kv_data: torch.Tensor, + paged_kv_cache: torch.Tensor, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, logits_soft_cap: Optional[float] = None, @@ -200,14 +200,14 @@ def forward( ): r"""Compute block-sparse attention between Q/K/V tensors. - Warning(Zihao): in the next release, kv_data will be decoupled into standalone k/v tensors, each + Warning(Zihao): in the next release, paged_kv_cache will be decoupled into standalone k/v tensors, each with shape (N, num_kv_heads, head_dim). Parameters ---------- q : torch.Tensor The query tensor, shape (M, num_qo_heads, head_dim). - kv_data : torch.Tensor + paged_kv_cache : torch.Tensor The key/value tensor, shape (N // C, 2, C, num_kv_heads, head_dim). pos_encoding_mode : str, optional The position encoding applied inside attention kernels, could be @@ -236,7 +236,7 @@ def forward( torch.Tensor The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ - check_pos_encoding_mode(pos_encoding_mode) + _check_pos_encoding_mode(pos_encoding_mode) if logits_soft_cap is None: logits_soft_cap = 0.0 if sm_scale is None: @@ -245,21 +245,12 @@ def forward( rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 - if is_float8(q): - logging.warning( - "Our current prefill kernel implementation needs f16 input, the f8 inputs " - " are casted to f16, which could result in performance degradation." - ) - q = q.to(torch.float16) - kv_data = kv_data.to(torch.float16) - - kv_data = expand_5d(kv_data, self._kv_layout) if self._packed_mask_buf is None: return self._wrapper.forward( q, self._qo_indptr, - kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len, @@ -276,7 +267,7 @@ def forward( return self._wrapper.forward_custom_mask( q, self._qo_indptr, - kv_data, + *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len, diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index c1753652..ac2eae1e 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -16,6 +16,7 @@ import torch from enum import Enum +from typing import Optional, Tuple, Union class PosEncodingMode(Enum): @@ -29,15 +30,17 @@ class TensorLayout(Enum): HND = 1 -def expand_5d(x: torch.Tensor, kv_layout: str): +def _expand_5d(x: torch.Tensor, kv_layout: str): if not x.ndim in [4, 5]: raise ValueError("x must be 4D or 5D") if x.ndim == 4: # page_size == 1 if kv_layout == "NHD": + # (num_pages, 2, num_heads, head_dim) -> (num_pages, 2, page_size=1, num_heads, head_dim) # expand to 5D on the 3nd last dimension return x.unsqueeze(-3) elif kv_layout == "HND": + # (num_pages, 2, num_heads, head_dim) -> (num_pages, 2, num_heads, page_size=1, head_dim) # expand to 5D on the 2nd last dimension return x.unsqueeze(-2) else: @@ -45,12 +48,30 @@ def expand_5d(x: torch.Tensor, kv_layout: str): return x -def check_pos_encoding_mode(pos_encoding_mode: str): +def _expand_4d(x: torch.Tensor, kv_layout: str): + if not x.ndim in [3, 4]: + raise ValueError("x must be 3D or 4D") + if x.ndim == 3: + # page_size == 1 + if kv_layout == "NHD": + # (num_pages, num_heads, head_dim) -> (num_pages, page_size=1, num_heads, head_dim) + # expand to 4D on the 3nd last dimension + return x.unsqueeze(-3) + elif kv_layout == "HND": + # (num_pages, num_heads, head_dim) -> (num_pages, num_heads, page_size=1, head_dim) + # expand to 5D on the 2nd last dimension + return x.unsqueeze(-2) + else: + raise KeyError("Invalid kv_layout {}".format(kv_layout)) + return x + + +def _check_pos_encoding_mode(pos_encoding_mode: str): if not hasattr(PosEncodingMode, pos_encoding_mode): raise KeyError("Invalid pos_encoding_mode {}".format(pos_encoding_mode)) -def check_kv_layout(kv_layout: str): +def _check_kv_layout(kv_layout: str): if not hasattr(TensorLayout, kv_layout): raise KeyError("Invalide kv_layout {}".format(kv_layout)) @@ -64,3 +85,12 @@ def get_indptr(x: torch.Tensor): ret = torch.zeros(x.shape[0] + 1, dtype=x.dtype, device=x.device) ret[1:] = x.cumsum(0) return ret + +def _unpack_paged_kv_cache(paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], kv_layout: str) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + if isinstance(paged_kv_cache, tuple): + paged_k_cache, paged_v_cache = paged_kv_cache + return (None, _expand_4d(paged_k_cache, kv_layout), _expand_4d(paged_v_cache, kv_layout)) + elif torch.is_tensor(paged_kv_cache): + return (_expand_5d(paged_kv_cache, kv_layout), None, None) + else: + raise KeyError("Unrecongized paged_kv_cache type {}, expect a single tensor or a tuple of tensor.".format(type(paged_kv_cache))) \ No newline at end of file From 93af3b8ee56b8bcd04dda04e8d6561682d3196f4 Mon Sep 17 00:00:00 2001 From: yzh119 Date: Sat, 20 Jul 2024 01:29:21 +0000 Subject: [PATCH 3/3] upd --- docs/api/python/cascade.rst | 6 - docs/api/python/decode.rst | 6 - docs/tutorials/kv_layout.rst | 17 +- include/flashinfer/page.cuh | 12 +- python/csrc/batch_decode.cu | 2 +- python/csrc/batch_prefill.cu | 11 +- python/csrc/page.cu | 3 +- python/flashinfer/cascade.py | 62 +++--- python/flashinfer/decode.py | 42 ++-- python/flashinfer/page.py | 18 +- python/flashinfer/prefill.py | 48 +++-- python/flashinfer/sparse.py | 144 +++++++++++--- python/flashinfer/utils.py | 18 +- python/tests/test_batch_decode_kernels.py | 125 ++++++++++++ python/tests/test_batch_prefill_kernels.py | 221 +++++++++++++++++++++ python/tests/test_block_sparse.py | 27 +-- 16 files changed, 622 insertions(+), 140 deletions(-) diff --git a/docs/api/python/cascade.rst b/docs/api/python/cascade.rst index 71170aca..1475ea0e 100644 --- a/docs/api/python/cascade.rst +++ b/docs/api/python/cascade.rst @@ -22,12 +22,6 @@ Merge Attention States Cascade Attention ----------------- -.. autosummary:: - :toctree: ../../generated - - batch_decode_with_shared_prefix_padded_kv_cache - - Cascade Attention Wrapper Classes --------------------------------- diff --git a/docs/api/python/decode.rst b/docs/api/python/decode.rst index eb4d06a3..b7d68a8e 100644 --- a/docs/api/python/decode.rst +++ b/docs/api/python/decode.rst @@ -16,12 +16,6 @@ Single Request Decoding Batch Decoding -------------- -.. autosummary:: - :toctree: ../../generated - - batch_decode_with_padded_kv_cache - batch_decode_with_padded_kv_cache_return_lse - .. autoclass:: BatchDecodeWithPagedKVCacheWrapper :members: diff --git a/docs/tutorials/kv_layout.rst b/docs/tutorials/kv_layout.rst index c29edcff..45e45e8e 100644 --- a/docs/tutorials/kv_layout.rst +++ b/docs/tutorials/kv_layout.rst @@ -119,14 +119,23 @@ The overall ``kv_indptr`` array (with length ``num_requests+1``) can be computed The overall ``kv_page_indices`` array (with length ``kv_indptr[-1]``) is the concatenation of all requests' ``page_indices``. The overall ``kv_last_page_lens`` array (with length ``num_requests``) is the concatenation of all requests' ``last_page_length``. -The ``kv_data`` tensor is a 5-D tensor with shape (in ``NHD`` layout): +The ``kv_data`` tensor could either be a single 5-D tensor or a tuple of 4-D tensors, +when stored in a single tensor, ``kv_data`` has shape: -.. code:: +.. code:: python - (max_num_pages, 2, page_size, num_heads, head_dim) + (max_num_pages, 2, page_size, num_heads, head_dim) # NHD layout + (max_num_pages, 2, num_heads, page_size, head_dim) # HND layout + +when stored in a tuple of tensors, ``kv_data = (k_data, v_data)``, and each one of them has shape: + +.. code:: python + + (max_num_pages, page_size, num_heads, head_dim) # NHD layout + (max_num_pages, num_heads, page_size, head_dim) # HND layout where ``max_num_pages`` is the maximum number of pages used by all requests, ``page_size`` is the number of tokens -we fit into each page. ``2`` is the number of slots in each page (first one for keys, the second one for values). +we fit into each page. ``2`` in single tensor storage means K/V (first one for keys, the second one for values). FlashInfer APIs ~~~~~~~~~~~~~~~ diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index 986f66b3..d79a5ff0 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -132,8 +132,8 @@ struct paged_kv_t { * \note This constructor should only be used when page_storage == kIndices */ __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, - uint32_t batch_size, QKVLayout layout, DType* kv_data, DType* k_data, - DType* v_data, IdType* indices, IdType* indptr, + uint32_t batch_size, QKVLayout layout, DType* kv_data, + DType* k_data, DType* v_data, IdType* indices, IdType* indptr, IdType* last_page_len, IdType* rope_pos_offset = nullptr) : num_heads(num_heads), page_size(page_size), @@ -146,12 +146,12 @@ struct paged_kv_t { bool kv_defined = kv_data != nullptr; if (kv_defined) { stride_page = 2 * num_heads * page_size * head_dim; - k_data = kv_data; - v_data = kv_data + num_heads * page_size * head_dim; + this->k_data = kv_data; + this->v_data = kv_data + num_heads * page_size * head_dim; } else { stride_page = num_heads * page_size * head_dim; - k_data = k_data; - v_data = v_data; + this->k_data = k_data; + this->v_data = v_data; } stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim; stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim; diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index b23ae49b..5b936246 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -192,7 +192,7 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( auto q_scalar_type = q.scalar_type(); auto kv_scalar_type = paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type(); - + if (q_scalar_type == kv_scalar_type) { DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] { return DISPATCH_logits_post_hook(logits_post_hook, LOGITS_POST_HOOK, [&] { diff --git a/python/csrc/batch_prefill.cu b/python/csrc/batch_prefill.cu index d07c2aa2..47c4c6ea 100644 --- a/python/csrc/batch_prefill.cu +++ b/python/csrc/batch_prefill.cu @@ -124,13 +124,14 @@ std::vector BatchPrefillWithPagedKVCachePyTorchWrapper::Forward( num_kv_heads = paged_kv_cache->size(3); } } else { - CHECK_EQ(paged_kv_cache->size(3), head_dim); + CHECK_EQ(paged_k_cache->size(3), head_dim); + CHECK_EQ(paged_v_cache->size(3), head_dim); if (kv_layout_ == QKVLayout::kHND) { - num_kv_heads = paged_kv_cache->size(1); - page_size = paged_kv_cache->size(2); + num_kv_heads = paged_k_cache->size(1); + page_size = paged_k_cache->size(2); } else { - page_size = paged_kv_cache->size(1); - num_kv_heads = paged_kv_cache->size(2); + page_size = paged_k_cache->size(1); + num_kv_heads = paged_k_cache->size(2); } } CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads); diff --git a/python/csrc/page.cu b/python/csrc/page.cu index 586c1c98..12461c82 100644 --- a/python/csrc/page.cu +++ b/python/csrc/page.cu @@ -122,6 +122,5 @@ void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, return true; }); - TORCH_CHECK(success, "AppendPagedKVCache failed to dispatch with dtype ", - kv_scalar_dtype); + TORCH_CHECK(success, "AppendPagedKVCache failed to dispatch with dtype ", kv_scalar_dtype); } diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 58d27ecb..f8046042 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -207,7 +207,7 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper: >>> unique_kv_last_page_len = torch.tensor( ... [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0" ... ) - >>> unique_kv_data_at_layer = [ + >>> unique_kv_cache_at_layer = [ ... torch.randn( ... max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" ... ) for _ in range(num_layers) @@ -238,9 +238,9 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper: ... q = torch.randn(batch_size, num_qo_heads, head_dim).half().to("cuda:0") ... k_shared = shared_k_data_at_layer[i] ... v_shared = shared_v_data_at_layer[i] - ... unique_kv_data = unique_kv_data_at_layer[i] + ... unique_kv_cache = unique_kv_cache_at_layer[i] ... # compute batch decode attention, reuse auxiliary data structures for all layers - ... o = wrapper.forward(q, k_shared, v_shared, unique_kv_data) + ... o = wrapper.forward(q, k_shared, v_shared, unique_kv_cache) ... outputs.append(o) ... >>> # clear auxiliary data structures @@ -339,7 +339,7 @@ def forward( q: torch.Tensor, k_shared: torch.Tensor, v_shared: torch.Tensor, - unique_kv_data: torch.Tensor, + unique_kv_cache: torch.Tensor, allow_fp16_qk_reduction=False, sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, @@ -362,13 +362,20 @@ def forward( ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, or ``[num_kv_heads, shared_prefix_len, head_dim]`` if :attr:`kv_layout` is ``HND``. - unique_kv_data : torch.Tensor - A 5-D tensor of paged kv-cache data storing the request-independent suffix - key and value tensors, shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``NHD``, or - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``HND``. + unique_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + The request-independent suffix paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: + ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. + + * a single 5-D tensor with shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, and + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and + ``paged_kv_cache[:, 1]`` is the value-cache. + allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). @@ -399,7 +406,7 @@ def forward( ) V_unique, S_unique = self._batch_decode_wrapper.forward_return_lse( q, - unique_kv_data, + unique_kv_cache, pos_encoding_mode="NONE", sm_scale=sm_scale, rope_scale=rope_scale, @@ -444,7 +451,7 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: >>> paged_kv_last_page_len= torch.tensor( ... [1, 7, 14, 4, 3, 1, 16], dtype=torch.int32, device="cuda:0" ... ) - >>> kv_data_at_layer = [ + >>> kv_cache_at_layer = [ ... torch.randn( ... max_num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0" ... ) for _ in range(num_layers) @@ -473,12 +480,12 @@ class BatchPrefillWithSharedPrefixPagedKVCacheWrapper: >>> outputs = [] >>> for i in range(num_layers): ... q = torch.randn(nnz_qo, num_qo_heads, head_dim).half().to("cuda:0") - ... kv_data = kv_data_at_layer[i] + ... kv_cache = kv_cache_at_layer[i] ... k_shared = shared_k_data_at_layer[i] ... v_shared = shared_v_data_at_layer[i] ... # compute batch prefill attention, reuse auxiliary data structures ... o = prefill_wrapper.forward( - ... q, k_shared, v_shared, kv_data, causal=True + ... q, k_shared, v_shared, kv_cache, causal=True ... ) ... outputs.append(o) ... @@ -588,7 +595,7 @@ def forward( q: torch.Tensor, k_shared: torch.Tensor, v_shared: torch.Tensor, - unique_kv_data: torch.Tensor, + unique_kv_cache: torch.Tensor, causal: bool = True, allow_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, @@ -612,13 +619,20 @@ def forward( ``[shared_prefix_len, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, or ``[num_kv_heads, shared_prefix_len, head_dim]`` if :attr:`kv_layout` is ``HND``. - unique_kv_data : torch.Tensor - A 5-D tensor of paged kv-cache data storing the request-independent suffix - key and value tensors, shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``NHD``, or - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``HND``. + unique_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + The request-independent suffix paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: + ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. + + * a single 5-D tensor with shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, and + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and + ``paged_kv_cache[:, 1]`` is the value-cache. + causal : bool Whether to apply causal mask on the attention matrix. allow_fp16_qk_reduction : bool @@ -651,7 +665,7 @@ def forward( ) V_unique, S_unique = self._batch_prefill_wrapper.forward_return_lse( q, - unique_kv_data, + unique_kv_cache, causal=causal, pos_encoding_mode="NONE", allow_fp16_qk_reduction=allow_fp16_qk_reduction, diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index abe78c06..4da2468f 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -290,7 +290,7 @@ def __init__( use_cuda_graph : bool Whether to enable CUDAGraph for batch decode attention, if enabled, the - auxiliary data structures will be stored in the provided buffers. The ``batch_size`` + auxiliary data structures will be stored as the provided buffers. The ``batch_size`` cannot change during the lifecycle of this wrapper when CUDAGraph is enabled. use_tensor_cores : bool @@ -554,12 +554,20 @@ def forward( ---------- q : torch.Tensor The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` - paged_kv_cache : torch.Tensor - A 5-D tensor of the reserved paged kv-cache data, shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``NHD``, or - ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if - :attr:`kv_layout` is ``HND``. + paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + The paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: + ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. + + * a single 5-D tensor with shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, and + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and + ``paged_kv_cache[:, 1]`` is the value-cache. + pos_encoding_mode : str The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. @@ -659,12 +667,20 @@ def forward_return_lse( ---------- q : torch.Tensor The query tensor, shape: ``[batch_size, num_qo_heads, head_dim]`` - paged_kv_cache : torch.Tensor - A 5-D tensor of the reserved paged kv-cache data, shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``NHD``, or - ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if - :attr:`kv_layout` is ``HND``. + paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + The paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: + ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. + + * a single 5-D tensor with shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, and + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and + ``paged_kv_cache[:, 1]`` is the value-cache. + pos_encoding_mode : str The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. diff --git a/python/flashinfer/page.py b/python/flashinfer/page.py index 236d534d..c6545941 100644 --- a/python/flashinfer/page.py +++ b/python/flashinfer/page.py @@ -54,11 +54,19 @@ def append_paged_kv_cache( append_indptr : torch.Tensor The indptr tensor of the key-value pairs to append, shape: ``[batch_size + 1]``. paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] - The 5-D tensor of the paged key-value cache, shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if - :attr:`kv_layout` is ``NHD``, or - ``[max_num_pages, 2, num_kv_heads, page_size, num_kv_heads]`` if - :attr:`kv_layout` is ``NHD``. + The paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: + ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. + + * a single 5-D tensor with shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, and + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and + ``paged_kv_cache[:, 1]`` is the value-cache. + kv_indices : torch.Tensor The page indices of the paged kv-cache, shape: ``[kv_indptr[-1]]``. kv_indptr : torch.Tensor diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 07c7282b..5c9ed44c 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -552,7 +552,7 @@ def __init__( use_cuda_graph : bool Whether to enable CUDA graph capture for the prefill kernels, if enabled, the - auxiliary data structures will be stored in provided buffers. The ``batch_size`` + auxiliary data structures will be stored as provided buffers. The ``batch_size`` cannot change during the lifecycle of this wrapper when CUDAGraph is enabled. qo_indptr_buf : Optional[torch.Tensor] @@ -816,12 +816,20 @@ def forward( ---------- q : torch.Tensor The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` - paged_kv_cache : torch.Tensor - A 5-D tensor of the reserved paged kv-cache data, shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` - if :attr:`kv_layout` is ``NHD``, or - ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` - if :attr:`kv_layout` is ``HND``. + paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + The paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: + ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. + + * a single 5-D tensor with shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, and + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and + ``paged_kv_cache[:, 1]`` is the value-cache. + causal : bool Whether to apply causal mask to the attention matrix. This is only effective when :attr:`custom_mask` is not provided in @@ -917,12 +925,20 @@ def forward_return_lse( ---------- q : torch.Tensor The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` - paged_kv_cache : torch.Tensor - A 5-D tensor of the reserved paged kv-cache data, shape: - ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` - if :attr:`kv_layout` is ``NHD``, or - ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if - :attr:`kv_layout` is ``HND``. + paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + The paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: + ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. + + * a single 5-D tensor with shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, and + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``NHD``. Where ``paged_kv_cache[:, 0]`` is the key-cache and + ``paged_kv_cache[:, 1]`` is the value-cache. + causal : bool Whether to apply causal mask to the attention matrix. pos_encoding_mode : str @@ -1132,7 +1148,7 @@ def __init__( use_cuda_graph : bool Whether to enable CUDA graph capture for the prefill kernels, if enabled, the - auxiliary data structures will be stored in the provided buffers. + auxiliary data structures will be stored as the provided buffers. qo_indptr_buf : Optional[torch.Tensor] The user reserved GPU buffer to store the ``qo_indptr`` array, the size of the buffer @@ -1342,7 +1358,7 @@ def forward( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): - r"""Compute batch prefill/append attention between query and kv-cache stored in + r"""Compute batch prefill/append attention between query and kv-cache stored as ragged tensor. Parameters @@ -1447,7 +1463,7 @@ def forward_return_lse( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): - r"""Compute batch prefill/append attention between query and kv-cache stored in + r"""Compute batch prefill/append attention between query and kv-cache stored as ragged tensor. Return attention output and logsumexp of attention scores. Parameters diff --git a/python/flashinfer/sparse.py b/python/flashinfer/sparse.py index 3c18e94c..784e037a 100644 --- a/python/flashinfer/sparse.py +++ b/python/flashinfer/sparse.py @@ -22,7 +22,6 @@ from .quantization import segment_packbits from .utils import ( _check_pos_encoding_mode, - _check_kv_layout, _unpack_paged_kv_cache, is_float8, PosEncodingMode, @@ -42,33 +41,98 @@ raise e +def convert_bsr_mask_layout(mask: torch.Tensor, indptr: torch.Tensor): + r"""Convert mask from BSR data layout to flashinfer's flattened mask layout. + + Parameters + ---------- + mask : torch.Tensor + A boolean mask tensor with shape ``(nnz, R, C)``. + indptr : torch.Tensor + The indptr tensor in BSR format. + + Returns + ------- + flattened_mask : torch.Tensor + A flattenedd mask tensor with shape ``(nnz * R * C,)``. + """ + nnz, R, C = mask.shape + MB = len(indptr) - 1 + mask_flashinfer = torch.empty((nnz * R * C,), dtype=mask.dtype, device=mask.device) + for i in range(MB): + mask_flashinfer[indptr[i] * R * C : indptr[i + 1] * R * C] = ( + mask[indptr[i] : indptr[i + 1]].transpose(0, 1).reshape(-1) + ) + return mask_flashinfer + + class BlockSparseAttentionWrapper: + r"""Wrapper class for attention computation with a block-sparse matrix as attention mask. + The definition of block sparse matrix can be found at + `bsr_matrix `_ + in SciPy. + + This API supports any block size ``(R, C)``. + + Example + ------- + >>> import torch + >>> import flashinfer + >>> num_qo_heads = 32 + >>> num_kv_heads = 8 + >>> head_dim = 128 + >>> # allocate 128MB workspace buffer + >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0") + >>> bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(workspace_buffer) + >>> # sparse mask: [[0, 0, 1], [1, 0, 1], [0, 1, 1]] + >>> M = 3 + >>> N = 3 + >>> indptr = torch.tensor([0, 1, 3, 5], dtype=torch.int32, device="cuda:0") + >>> indices = torch.tensor([2, 0, 2, 1, 2], dtype=torch.int32, device="cuda:0") + >>> bsr_wrapper.begin_forward( + ... indptr, + ... indices, + ... M, + ... N, + ... 1, # R(block_rows)=1 + ... 1, # C(block_columns)=1 + ... num_qo_heads, + ... num_kv_heads, + ... head_dim, + ... ) + >>> q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device="cuda:0") + >>> k = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0") + >>> v = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device="cuda:0") + >>> o = bsr_wrapper.forward(q, k, v) + >>> # use dense implementation with attention mask for comparison + >>> mask = torch.tensor([[0, 0, 1], [1, 0, 1], [0, 1, 1]], dtype=torch.bool, device="cuda:0") + >>> o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask) + >>> torch.allclose(o, o_ref) + True + """ + def __init__( self, workspace_buffer: torch.Tensor, - kv_layout: str = "NHD", ): r"""Constructs of :class:`BlockSparseAttentionWrapper`. - Warning(Zihao): this is an experimental API and subject to change. - Parameters ---------- workspace_buffer : torch.Tensor The user reserved workspace buffer used to store auxiliary data structures, recommended size is 128MB, the device of the workspace buffer should be the same as the device of the input tensors. - - kv_layout : str - The layout of the input k/v tensors, could be either ``NHD`` or ``HND``. """ - _check_kv_layout(kv_layout) - self._kv_layout = kv_layout self._workspace_buffer = workspace_buffer self._wrapper = _kernels.BatchPrefillWithPagedKVCachePyTorchWrapper( - TensorLayout[kv_layout].value, + TensorLayout["NHD"].value, False, # use_cuda_graph ) + self.R = None + self.C = None + self.M = None + self.N = None def begin_forward( self, @@ -90,13 +154,16 @@ def begin_forward( Parameters ---------- indptr : torch.Tensor - The indptr of the block-sparse matrix, shape (MB + 1,), where MB is the number of blocks in the row dimension. + The block index pointer of the block-sparse matrix on row dimension, shape ``(MB + 1,)``, + where ``MB`` is the number of blocks in the row dimension. indices: torch.Tensor - The indices of the block-sparse matrix, shape (nnz,), where nnz is the number of non-zero blocks. + The block indices of the block-sparse matrix on column dimension, shape ``(nnz,)``, where + ``nnz`` is the number of non-zero blocks. The elements in ``indices`` array should be less then ``NB``: + the number of blocks in the column dimension. M : int - The number of rows of the block-sparse matrix, MB = ceil_div(M, R). + The number of rows of the block-sparse matrix, ``MB = ceil_div(M, R)``. N : int - The number of columns of the block-sparse matrix, NB = ceil_div(N, C). + The number of columns of the block-sparse matrix, ``NB = N // C``, ``N`` should be divisible by ``C``. R : int The number of rows in each block. C : int @@ -108,7 +175,7 @@ def begin_forward( head_dim : int The dimension of each head. mask : torch.Tensor, optional - The flattened mask tensor, shape (nnz * R * C,), where nnz is the number of non-zero blocks. + The mask tensor with shape ``(nnz, R, C,)``, where nnz is the number of non-zero blocks. If every block is full, then we don't need to provide the mask tensor. packed_mask : torch.Tensor, optional The 1D packed mask tensor, if provided, the :attr:`custom_mask` will be ignored. @@ -124,16 +191,16 @@ def begin_forward( is not equal to ``num_kv_heads``, the function will use `grouped query attention `_. """ - num_rows = len(indptr) - 1 - qo_indptr_host = R * torch.arange(num_rows + 1, dtype=torch.int32) + num_blocks_row = len(indptr) - 1 + qo_indptr_host = R * torch.arange(num_blocks_row + 1, dtype=torch.int32) qo_indptr_host[-1] = M self._qo_indptr = qo_indptr_host.to(indptr.device) row_empty = indptr[1:] == indptr[:1] if indices.max().item() * C > N: raise ValueError("indices out of bound") - last_block_pos = indices[torch.clamp(indptr[1:], min=1) - 1] - last_block_pos.masked_fill_(row_empty, 0) - last_block_len = torch.clamp(N - last_block_pos * C, max=C) + last_block_len = torch.full( + (num_blocks_row,), C, dtype=torch.int32, device=indptr.device + ) if mask is not None or packed_mask is not None: qk_indptr = _compute_page_qk_indptr( @@ -143,6 +210,8 @@ def begin_forward( C, # page_size ) if packed_mask is None and mask is not None: + # first convert BSR mask to flashinfer layout + mask = convert_bsr_mask_layout(mask, indptr) # create packed mask from mask packed_mask, qk_indptr = segment_packbits( mask.contiguous().view(-1), qk_indptr, bitorder="little" @@ -166,11 +235,16 @@ def begin_forward( ), ) + self.M = M + self.N = N + self.R = R + self.C = C + self._wrapper.begin_forward( self._workspace_buffer, self._qo_indptr, self._paged_kv_indptr_buf, - num_rows, + num_blocks_row, num_qo_heads, num_kv_heads, head_dim, @@ -186,11 +260,16 @@ def end_forward(self): self._paged_kv_last_page_len = None self._packed_mask_buf = None self._qk_indptr_buf = None + self.M = None + self.N = None + self.R = None + self.C = None def forward( self, q: torch.Tensor, - paged_kv_cache: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, pos_encoding_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, logits_soft_cap: Optional[float] = None, @@ -200,15 +279,14 @@ def forward( ): r"""Compute block-sparse attention between Q/K/V tensors. - Warning(Zihao): in the next release, paged_kv_cache will be decoupled into standalone k/v tensors, each - with shape (N, num_kv_heads, head_dim). - Parameters ---------- q : torch.Tensor - The query tensor, shape (M, num_qo_heads, head_dim). - paged_kv_cache : torch.Tensor - The key/value tensor, shape (N // C, 2, C, num_kv_heads, head_dim). + The query tensor with shape ``(M, num_qo_heads, head_dim)``. + k : torch.Tensor + The key tensor with shape ``(N, num_kv_heads, head_dim)``. + v : torch.Tensor + The value tensor with shape ``(N, num_kv_heads, head_dim)``. pos_encoding_mode : str, optional The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. @@ -246,11 +324,15 @@ def forward( if rope_theta is None: rope_theta = 1e4 + k = k.reshape(-1, self.C, *k.shape[-2:]).contiguous() + v = v.reshape(-1, self.C, *v.shape[-2:]).contiguous() if self._packed_mask_buf is None: return self._wrapper.forward( q, self._qo_indptr, - *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), + None, + k, + v, self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len, @@ -267,7 +349,9 @@ def forward( return self._wrapper.forward_custom_mask( q, self._qo_indptr, - *_unpack_paged_kv_cache(paged_kv_cache, self._kv_layout), + None, + k, + v, self._paged_kv_indptr_buf, self._paged_kv_indices_buf, self._paged_kv_last_page_len, diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index ac2eae1e..4733d4a2 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -86,11 +86,23 @@ def get_indptr(x: torch.Tensor): ret[1:] = x.cumsum(0) return ret -def _unpack_paged_kv_cache(paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], kv_layout: str) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + +def _unpack_paged_kv_cache( + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + kv_layout: str, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: if isinstance(paged_kv_cache, tuple): paged_k_cache, paged_v_cache = paged_kv_cache - return (None, _expand_4d(paged_k_cache, kv_layout), _expand_4d(paged_v_cache, kv_layout)) + return ( + None, + _expand_4d(paged_k_cache, kv_layout), + _expand_4d(paged_v_cache, kv_layout), + ) elif torch.is_tensor(paged_kv_cache): return (_expand_5d(paged_kv_cache, kv_layout), None, None) else: - raise KeyError("Unrecongized paged_kv_cache type {}, expect a single tensor or a tuple of tensor.".format(type(paged_kv_cache))) \ No newline at end of file + raise KeyError( + "Unrecongized paged_kv_cache type {}, expect a single tensor or a tuple of tensor.".format( + type(paged_kv_cache) + ) + ) diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index 3659dc3d..cc8e34c0 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -139,6 +139,128 @@ def test_batch_decode_with_paged_kv_cache( numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("kv_len", [54, 97, 512]) +@pytest.mark.parametrize("page_size", [1, 8, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +@pytest.mark.parametrize("return_lse", [True, False]) +@pytest.mark.parametrize("q_dtype", [torch.float16]) +@pytest.mark.parametrize( + "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] +) +def test_batch_decode_with_tuple_paged_kv_cache( + batch_size, + kv_len, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + kv_layout, + pos_encoding_mode, + logits_soft_cap, + return_lse, + q_dtype, + kv_dtype, +): + q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(q_dtype) + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = tuple( + ( + torch.randn(total_num_pages, num_kv_heads, page_size, head_dim).to(0) + if kv_layout == "HND" + else torch.randn(total_num_pages, page_size, num_kv_heads, head_dim).to(0) + ) + for _ in range(2) + ) + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * num_pages_per_seq + kv_indices = torch.arange(0, total_num_pages).to(0).int() + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ).to(0) + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout) + wrapper.begin_forward( + kv_indptr, + kv_indices, + kv_last_page_len, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + "NONE", + logits_soft_cap=logits_soft_cap, + data_type=kv_dtype, + q_data_type=q_dtype, + ) + if return_lse: + o, _ = wrapper.forward_return_lse( + q, + tuple(map(lambda _: _.to(kv_dtype), kv_data)), + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + ) + else: + o = wrapper.forward( + q, + tuple(map(lambda _: _.to(kv_dtype), kv_data)), + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + ) + + k_cache, v_cache = kv_data + for i in range(batch_size): + perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] + perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] + qi = q[i] + ki = torch.cat( + [ + k_cache[kv_indptr[i] : kv_indptr[i + 1] - 1] + .permute(*perm_dims) + .reshape(-1, num_kv_heads, head_dim), + ( + k_cache[kv_indptr[i + 1] - 1, :, : kv_last_page_len[i]] + if kv_layout == "HND" + else k_cache[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :] + ) + .permute(*perm_dims_last) + .reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ).to(kv_dtype) + vi = torch.cat( + [ + v_cache[kv_indptr[i] : kv_indptr[i + 1] - 1] + .permute(*perm_dims) + .reshape(-1, num_kv_heads, head_dim), + ( + v_cache[kv_indptr[i + 1] - 1, :, : kv_last_page_len[i]] + if kv_layout == "HND" + else v_cache[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :] + ) + .permute(*perm_dims_last) + .reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ).to(kv_dtype) + o_ref_i = flashinfer.single_decode_with_kv_cache( + qi, + ki, + vi, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + ) + o_i_np = o[i].cpu().numpy() + o_ref_i_np = o_ref_i.cpu().numpy() + numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + + @pytest.mark.parametrize("batch_size", [12, 17]) @pytest.mark.parametrize("kv_len", [54, 2048]) @pytest.mark.parametrize("page_size", [1, 8, 16]) @@ -308,6 +430,9 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( test_batch_decode_with_paged_kv_cache( 256, 54, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16 ) + test_batch_decode_with_tuple_paged_kv_cache( + 256, 54, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16 + ) test_batch_decode_with_paged_kv_cache( 12, 2048, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16 ) diff --git a/python/tests/test_batch_prefill_kernels.py b/python/tests/test_batch_prefill_kernels.py index 081e09f9..c311a333 100644 --- a/python/tests/test_batch_prefill_kernels.py +++ b/python/tests/test_batch_prefill_kernels.py @@ -239,6 +239,224 @@ def test_batch_prefill_with_paged_kv_cache( numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) +@pytest.mark.parametrize("batch_size", [12, 17]) +@pytest.mark.parametrize("kv_len", [54, 97]) +@pytest.mark.parametrize("qo_len", [37, 17]) +@pytest.mark.parametrize("page_size", [1, 5, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 32]) +@pytest.mark.parametrize("head_dim", [128, 256]) +@pytest.mark.parametrize("causal", [False, True]) +@pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) +@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) +@pytest.mark.parametrize("use_cuda_graph", [False, True]) +@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0]) +@pytest.mark.parametrize("return_lse", [True, False]) +def test_batch_prefill_with_tuple_paged_kv_cache( + batch_size, + kv_len, + qo_len, + page_size, + num_kv_heads, + num_qo_heads, + head_dim, + causal, + kv_layout, + pos_encoding_mode, + use_cuda_graph, + logits_soft_cap, + return_lse, +): + q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half() + q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + kv_data = tuple( + ( + torch.randn(total_num_pages, num_kv_heads, page_size, head_dim).to(0).half() + if kv_layout == "HND" + else torch.randn(total_num_pages, page_size, num_kv_heads, head_dim) + .to(0) + .half() + ) + for _ in range(2) + ) + kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq + kv_indices_cpu = torch.arange(0, total_num_pages).int() + kv_last_page_len_cpu = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 + ) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0) + if not use_cuda_graph: + q_indptr_gpu = q_indptr_cpu.to(0) + kv_indptr_gpu = kv_indptr_cpu.to(0) + kv_indices_gpu = kv_indices_cpu.to(0) + kv_last_page_len_gpu = kv_last_page_len_cpu.to(0) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + wrapper.begin_forward( + q_indptr_gpu, + kv_indptr_gpu, + kv_indices_gpu, + kv_last_page_len_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + ) + if return_lse: + o, _ = wrapper.forward_return_lse( + q, + kv_data, + causal=causal, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + ) + else: + o = wrapper.forward( + q, + kv_data, + causal=causal, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + ) + else: + q_indptr_buffer = torch.empty(batch_size + 1).int().to(0) + kv_indptr_buffer = torch.empty(batch_size + 1).int().to(0) + kv_indices_buffer = torch.empty(total_num_pages).int().to(0) + kv_last_page_len_buffer = torch.empty(batch_size).int().to(0) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_cuda_graph=True, + qo_indptr_buf=q_indptr_buffer, + paged_kv_indptr_buf=kv_indptr_buffer, + paged_kv_indices_buf=kv_indices_buffer, + paged_kv_last_page_len_buf=kv_last_page_len_buffer, + ) + q_indptr_warmup = torch.arange(0, batch_size + 1).int() * qo_len + kv_indptr_warmup = torch.arange(0, batch_size + 1).int() + kv_indices_warmup = torch.arange(0, batch_size).int() + kv_last_page_len_warmup = torch.full( + (batch_size,), page_size, dtype=torch.int32 + ) + wrapper.begin_forward( + q_indptr_warmup, + kv_indptr_warmup, + kv_indices_warmup, + kv_last_page_len_warmup, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + ) + + # warmup + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + if return_lse: + o, _ = wrapper.forward_return_lse( + q, + kv_data, + causal=causal, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + ) + else: + o = wrapper.forward( + q, + kv_data, + causal=causal, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + ) + torch.cuda.current_stream().wait_stream(s) + # capture + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + if return_lse: + o, _ = wrapper.forward_return_lse( + q, + kv_data, + causal=causal, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + ) + else: + o = wrapper.forward( + q, + kv_data, + causal=causal, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + ) + wrapper.end_forward() + + wrapper.begin_forward( + q_indptr_cpu, + kv_indptr_cpu, + kv_indices_cpu, + kv_last_page_len_cpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + ) + + g.replay() + + k_cache, v_cache = kv_data + for i in range(batch_size): + perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] + perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] + qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]] + ki = torch.cat( + [ + k_cache[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1] + .permute(*perm_dims) + .reshape(-1, num_kv_heads, head_dim), + ( + k_cache[kv_indptr_cpu[i + 1] - 1, :, : kv_last_page_len_cpu[i]] + if kv_layout == "HND" + else k_cache[kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i], :] + ) + .permute(*perm_dims_last) + .reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ) + vi = torch.cat( + [ + v_cache[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1] + .permute(*perm_dims) + .reshape(-1, num_kv_heads, head_dim), + ( + v_cache[kv_indptr_cpu[i + 1] - 1, :, : kv_last_page_len_cpu[i]] + if kv_layout == "HND" + else v_cache[kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i], :] + ) + .permute(*perm_dims_last) + .reshape(-1, num_kv_heads, head_dim), + ], + dim=0, + ) + o_ref_i = flashinfer.single_prefill_with_kv_cache( + qi, + ki, + vi, + causal=causal, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, + ) + o_i_np = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]].cpu().numpy() + o_ref_i_np = o_ref_i.cpu().numpy() + numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3) + + @pytest.mark.parametrize("batch_size", [12, 17]) @pytest.mark.parametrize("kv_len", [54, 97]) @pytest.mark.parametrize("qo_len", [37, 17]) @@ -525,6 +743,9 @@ def test_batch_prefill_with_ragged_kv_cache_custom_mask( test_batch_prefill_with_paged_kv_cache( 12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False ) + test_batch_prefill_with_tuple_paged_kv_cache( + 12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False + ) test_batch_prefill_with_paged_kv_cache( 12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE", False, 0.0, False ) diff --git a/python/tests/test_block_sparse.py b/python/tests/test_block_sparse.py index 1bb9f919..c91b9044 100644 --- a/python/tests/test_block_sparse.py +++ b/python/tests/test_block_sparse.py @@ -23,21 +23,19 @@ def bsr_attention_ref( q, - kv, + k, + v, indptr, indices, mask_data, ): M = q.shape[0] - NB, _, C, H_KV, D = kv.shape - N = NB * C + N = k.shape[0] bsr = sp.sparse.bsr_matrix( (mask_data.cpu().numpy(), indices.cpu().numpy(), indptr.cpu().numpy()), shape=(M, N), ) dense_mask = torch.tensor(bsr.toarray(), dtype=bool, device=q.device) - k = kv[:, 0].reshape(-1, H_KV, D).contiguous() - v = kv[:, 1].reshape(-1, H_KV, D).contiguous() o = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=dense_mask) return o @@ -67,21 +65,13 @@ def test_block_sparse_attention( else: data_mask = torch.full((nnz, R, C), True, dtype=bool, device=0) q = torch.randn((M, num_qo_heads, head_dim), dtype=torch.float16, device=0) - kv_data = torch.randn( - (NB, 2, C, num_kv_heads, head_dim), dtype=torch.float16, device=0 - ) + k = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device=0) + v = torch.randn((N, num_kv_heads, head_dim), dtype=torch.float16, device=0) - o_ref = bsr_attention_ref(q, kv_data, indptr, indices, data_mask) + o_ref = bsr_attention_ref(q, k, v, indptr, indices, data_mask) workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device=0) sparse_attention_wrapper = flashinfer.BlockSparseAttentionWrapper(workspace_buffer) - if mask_inside_block: - mask_flashinfer_layout = torch.full((nnz * R * C,), False, dtype=bool, device=0) - for i in range(MB): - mask_flashinfer_layout[indptr[i] * R * C : indptr[i + 1] * R * C] = ( - data_mask[indptr[i] : indptr[i + 1]].transpose(0, 1).reshape(-1) - ) - sparse_attention_wrapper.begin_forward( indptr, indices, @@ -92,12 +82,11 @@ def test_block_sparse_attention( num_qo_heads, num_kv_heads, head_dim, - mask=mask_flashinfer_layout if mask_inside_block else None, + mask=data_mask if mask_inside_block else None, ) - o = sparse_attention_wrapper.forward(q, kv_data) + o = sparse_attention_wrapper.forward(q, k, v) sparse_attention_wrapper.end_forward() - print(o_ref, o) np.testing.assert_allclose(o_ref.cpu(), o.cpu(), atol=1e-2, rtol=1e-3)