
feat: expose decoupled kv-cache to pytorch api #383

Merged (3 commits) on Jul 20, 2024
6 changes: 0 additions & 6 deletions docs/api/python/cascade.rst
@@ -22,12 +22,6 @@ Merge Attention States
Cascade Attention
-----------------

-.. autosummary::
-   :toctree: ../../generated
-
-   batch_decode_with_shared_prefix_padded_kv_cache
-

Cascade Attention Wrapper Classes
---------------------------------

6 changes: 0 additions & 6 deletions docs/api/python/decode.rst
@@ -16,12 +16,6 @@ Single Request Decoding
Batch Decoding
--------------

-.. autosummary::
-   :toctree: ../../generated
-
-   batch_decode_with_padded_kv_cache
-   batch_decode_with_padded_kv_cache_return_lse
-
.. autoclass:: BatchDecodeWithPagedKVCacheWrapper
   :members:

17 changes: 13 additions & 4 deletions docs/tutorials/kv_layout.rst
@@ -119,14 +119,23 @@ The overall ``kv_indptr`` array (with length ``num_requests+1``) can be computed
The overall ``kv_page_indices`` array (with length ``kv_indptr[-1]``) is the concatenation of all requests' ``page_indices``.
The overall ``kv_last_page_lens`` array (with length ``num_requests``) is the concatenation of all requests' ``last_page_length``.

-The ``kv_data`` tensor is a 5-D tensor with shape (in ``NHD`` layout):
+The ``kv_data`` tensor can either be a single 5-D tensor or a tuple of 4-D tensors.
+When stored in a single tensor, ``kv_data`` has shape:

-.. code::
+.. code:: python

-    (max_num_pages, 2, page_size, num_heads, head_dim)
+    (max_num_pages, 2, page_size, num_heads, head_dim)  # NHD layout
+    (max_num_pages, 2, num_heads, page_size, head_dim)  # HND layout

+When stored in a tuple of tensors, ``kv_data = (k_data, v_data)``, each of which has shape:
+
+.. code:: python
+
+    (max_num_pages, page_size, num_heads, head_dim)  # NHD layout
+    (max_num_pages, num_heads, page_size, head_dim)  # HND layout

where ``max_num_pages`` is the maximum number of pages used by all requests, ``page_size`` is the number of tokens
-we fit into each page. ``2`` is the number of slots in each page (first one for keys, the second one for values).
+we fit into each page. The ``2`` in the single-tensor shape separates keys and values: the first slot stores keys, the second stores values.

FlashInfer APIs
~~~~~~~~~~~~~~~
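To make the two storage schemes above concrete, here is a minimal sketch (illustrative, not part of this diff) that allocates both cache forms and builds the page-table arrays for a small batch. It assumes PyTorch and an ``NHD`` layout; all sizes are made up:

```python
import torch

max_num_pages, page_size, num_kv_heads, head_dim = 128, 16, 8, 128

# Single-tensor storage (NHD): dim 1 has size 2 -- slot 0 holds keys, slot 1 holds values.
kv_data = torch.empty(max_num_pages, 2, page_size, num_kv_heads, head_dim,
                      dtype=torch.float16, device="cuda")

# Decoupled storage (NHD): separate K and V tensors, one page slot each.
k_data = torch.empty(max_num_pages, page_size, num_kv_heads, head_dim,
                     dtype=torch.float16, device="cuda")
v_data = torch.empty_like(k_data)

# Page table for a batch of 3 requests that use 3, 2, and 4 pages respectively.
kv_page_indices = torch.arange(9, dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 3, 5, 9], dtype=torch.int32, device="cuda")
# Number of valid entries in each request's last page (between 1 and page_size).
kv_last_page_len = torch.tensor([5, 16, 1], dtype=torch.int32, device="cuda")
```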
42 changes: 42 additions & 0 deletions include/flashinfer/page.cuh
@@ -115,6 +115,48 @@ struct paged_kv_t {
        last_page_len(nullptr),
        rope_pos_offset(nullptr) {}

+  /*!
+   * \brief Construct a paged key-value cache
+   * \param num_heads The number of heads
+   * \param page_size The size of each page
+   * \param head_dim The dimension of each head
+   * \param batch_size The batch size
+   * \param layout The layout of the last 3 dimensions in the KV-Cache
+   * \param kv_data The flattened combined key-value cache (pass nullptr when using separate caches)
+   * \param k_data The flattened key cache (used when kv_data is nullptr)
+   * \param v_data The flattened value cache (used when kv_data is nullptr)
+   * \param indices The page indices array
+   * \param indptr The page indptr array
+   * \param last_page_len The number of entries in the last page for each request in the batch
+   * \param rope_pos_offset The start position of each request in the batch
+   * \note This constructor should only be used when page_storage == kIndices
+   */
+  __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
+                                      uint32_t batch_size, QKVLayout layout, DType* kv_data,
+                                      DType* k_data, DType* v_data, IdType* indices, IdType* indptr,
+                                      IdType* last_page_len, IdType* rope_pos_offset = nullptr)
+      : num_heads(num_heads),
+        page_size(page_size),
+        head_dim(head_dim),
+        batch_size(batch_size),
+        indices(indices),
+        indptr(indptr),
+        last_page_len(last_page_len),
+        rope_pos_offset(rope_pos_offset) {
+    bool kv_defined = kv_data != nullptr;
+    if (kv_defined) {
+      // Combined storage: K and V slots of a page sit back to back.
+      stride_page = 2 * num_heads * page_size * head_dim;
+      this->k_data = kv_data;
+      this->v_data = kv_data + num_heads * page_size * head_dim;
+    } else {
+      // Decoupled storage: K and V live in separate tensors.
+      stride_page = num_heads * page_size * head_dim;
+      this->k_data = k_data;
+      this->v_data = v_data;
+    }
+    stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
+    stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
+  }

  /*!
   * \brief Construct a paged key-value cache
   * \param num_heads The number of heads
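The heart of the new constructor is the stride bookkeeping: with combined storage a page spans ``2 * num_heads * page_size * head_dim`` elements and ``v_data`` is derived from ``kv_data`` by a fixed offset, while with decoupled storage the page stride is half that. A small PyTorch sketch (illustrative, with made-up sizes) verifies this arithmetic against a contiguous ``NHD`` tensor:

```python
import torch

max_num_pages, page_size, num_heads, head_dim = 4, 16, 8, 32
kv_data = torch.arange(max_num_pages * 2 * page_size * num_heads * head_dim,
                       dtype=torch.float32)
kv_data = kv_data.reshape(max_num_pages, 2, page_size, num_heads, head_dim)  # NHD

flat = kv_data.flatten()                             # the raw buffer a kernel sees
stride_page = 2 * num_heads * page_size * head_dim   # combined-storage page stride
v_offset = num_heads * page_size * head_dim          # this->v_data = kv_data + v_offset

page = 2
k_page = flat[page * stride_page : page * stride_page + v_offset]
v_page = flat[page * stride_page + v_offset : (page + 1) * stride_page]
assert torch.equal(k_page, kv_data[page, 0].flatten())  # K slot of the page
assert torch.equal(v_page, kv_data[page, 1].flatten())  # V slot of the page
```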
80 changes: 61 additions & 19 deletions python/csrc/batch_decode.cu
@@ -105,40 +105,71 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::UpdatePageLockedBufferSize(
}

std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
-    torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
-    torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len,
-    unsigned int pos_encoding_mode, float logits_soft_cap, float sm_scale, float rope_scale,
-    float rope_theta, bool return_lse) {
+    torch::Tensor q, std::optional<torch::Tensor> paged_kv_cache,
+    std::optional<torch::Tensor> paged_k_cache, std::optional<torch::Tensor> paged_v_cache,
+    torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices,
+    torch::Tensor paged_kv_last_page_len, unsigned int pos_encoding_mode, float logits_soft_cap,
+    float sm_scale, float rope_scale, float rope_theta, bool return_lse) {
  CHECK_INPUT(q);
-  CHECK_INPUT(paged_kv_data);
+  bool paged_kv_defined = paged_kv_cache.has_value();
+  if (paged_kv_defined) {
+    CHECK_INPUT(paged_kv_cache.value());
+  } else {
+    CHECK_INPUT(paged_k_cache.value());
+    CHECK_INPUT(paged_v_cache.value());
+  }
  CHECK_INPUT(paged_kv_indptr);
  CHECK_INPUT(paged_kv_indices);
  CHECK_INPUT(paged_kv_last_page_len);
  auto device = q.device();
-  CHECK_EQ(paged_kv_data.device(), device);
+  if (paged_kv_defined) {
+    CHECK_EQ(paged_kv_cache->device(), device);
+  } else {
+    CHECK_EQ(paged_k_cache->device(), device);
+    CHECK_EQ(paged_v_cache->device(), device);
+  }
  CHECK_EQ(paged_kv_indices.device(), device);
  CHECK_EQ(paged_kv_indptr.device(), device);
  CHECK_EQ(paged_kv_last_page_len.device(), device);
  CHECK_DIM(3, q);                       // (B, H_qo, D)
  CHECK_DIM(1, paged_kv_last_page_len);  // (B,)
  CHECK_DIM(1, paged_kv_indptr);         // (B+1,)
  CHECK_DIM(1, paged_kv_indices);        // (nnz,)
-  // (num_max_pages, 2, H_kv, page_size, head_dim) for HND
-  // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
-  CHECK_DIM(5, paged_kv_data);
+  if (paged_kv_defined) {
+    // (num_max_pages, 2, H_kv, page_size, head_dim) for HND
+    // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD
+    CHECK_DIM(5, paged_kv_cache.value());
+  } else {
+    // (num_max_pages, H_kv, page_size, head_dim) for HND
+    // (num_max_pages, page_size, H_kv, head_dim) for NHD
+    CHECK_DIM(4, paged_k_cache.value());
+    CHECK_DIM(4, paged_v_cache.value());
+  }
  int64_t batch_size = q.size(0);
  int64_t num_qo_heads = q.size(1);
  int64_t head_dim = q.size(2);
  int64_t num_kv_heads, page_size;
-  if (kv_layout_ == QKVLayout::kHND) {
-    num_kv_heads = paged_kv_data.size(2);
-    page_size = paged_kv_data.size(3);
+  if (paged_kv_defined) {
+    CHECK_EQ(paged_kv_cache->size(1), 2);
+    CHECK_EQ(paged_kv_cache->size(4), head_dim);
+    if (kv_layout_ == QKVLayout::kHND) {
+      num_kv_heads = paged_kv_cache->size(2);
+      page_size = paged_kv_cache->size(3);
+    } else {
+      page_size = paged_kv_cache->size(2);
+      num_kv_heads = paged_kv_cache->size(3);
+    }
  } else {
-    page_size = paged_kv_data.size(2);
-    num_kv_heads = paged_kv_data.size(3);
+    CHECK_EQ(paged_k_cache->size(3), head_dim);
+    CHECK_EQ(paged_v_cache->size(3), head_dim);
+    if (kv_layout_ == QKVLayout::kHND) {
+      num_kv_heads = paged_k_cache->size(1);
+      page_size = paged_k_cache->size(2);
+    } else {
+      page_size = paged_k_cache->size(1);
+      num_kv_heads = paged_k_cache->size(2);
+    }
  }
-  CHECK_EQ(paged_kv_data.size(1), 2);
-  CHECK_EQ(paged_kv_data.size(4), head_dim);
  CHECK_GE(paged_kv_indptr.size(0), batch_size + 1);
  CHECK_GE(paged_kv_last_page_len.size(0), batch_size);
  // TODO(Zihao): support dispatching to different data types
@@ -159,7 +190,8 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
      logits_soft_cap > 0.f ? LogitsPostHook::kSoftCap : LogitsPostHook::kNone;

  auto q_scalar_type = q.scalar_type();
-  auto kv_scalar_type = paged_kv_data.scalar_type();
+  auto kv_scalar_type =
+      paged_kv_defined ? paged_kv_cache->scalar_type() : paged_k_cache->scalar_type();

  if (q_scalar_type == kv_scalar_type) {
    DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q_scalar_type, qkv_type, [&] {
@@ -169,7 +201,12 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
          PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
            paged_kv_t<PageStorage::kIndices, qkv_type, int32_t> paged_kv(
                num_kv_heads, page_size, head_dim, batch_size, kv_layout_,
-                static_cast<qkv_type*>(paged_kv_data.data_ptr()),
+                static_cast<qkv_type*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
+                                                                  : nullptr),
+                static_cast<qkv_type*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr()
+                                                                 : nullptr),
+                static_cast<qkv_type*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr()
+                                                                 : nullptr),
                static_cast<int32_t*>(paged_kv_indices.data_ptr()),
                static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
                static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
@@ -197,7 +234,12 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCachePyTorchWrapper::Forward(
          PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] {
            paged_kv_t<PageStorage::kIndices, kv_type, int32_t> paged_kv(
                num_kv_heads, page_size, head_dim, batch_size, kv_layout_,
-                static_cast<kv_type*>(paged_kv_data.data_ptr()),
+                static_cast<kv_type*>(paged_kv_cache.has_value() ? paged_kv_cache->data_ptr()
+                                                                 : nullptr),
+                static_cast<kv_type*>(paged_k_cache.has_value() ? paged_k_cache->data_ptr()
+                                                                : nullptr),
+                static_cast<kv_type*>(paged_v_cache.has_value() ? paged_v_cache->data_ptr()
+                                                                : nullptr),
                static_cast<int32_t*>(paged_kv_indices.data_ptr()),
                static_cast<int32_t*>(paged_kv_indptr.data_ptr()),
                static_cast<int32_t*>(paged_kv_last_page_len.data_ptr()));
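On the Python side, the net effect is that the decode wrapper's ``forward`` now accepts either cache form. Here is a usage sketch reusing the tensors built in the layout example above; the calls follow the wrapper API documented for this release, though the exact keyword names here are best-effort assumptions:

```python
import torch
import flashinfer

batch_size, num_qo_heads = 3, 32
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")
wrapper.begin_forward(kv_indptr, kv_page_indices, kv_last_page_len,
                      num_qo_heads, num_kv_heads, head_dim, page_size,
                      data_type=torch.float16)

q = torch.randn(batch_size, num_qo_heads, head_dim,
                dtype=torch.float16, device="cuda")
o = wrapper.forward(q, kv_data)           # single 5-D tensor, as before
o = wrapper.forward(q, (k_data, v_data))  # decoupled (K, V) tuple, added by this PR
wrapper.end_forward()
```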