From 89f2c4a816ff133e09cb9fc1d7c3de43d4431ffd Mon Sep 17 00:00:00 2001
From: LinHeLurking
Date: Fri, 25 Oct 2024 10:09:08 +0800
Subject: [PATCH] feat: non-contiguous query with paged kv cache (#553)

## Motivation

Previously, only the ragged version of the prefill kernel supported non-contiguous query tensors (#404). With a paged KV cache, the query tensor still had to be made contiguous: libraries like vLLM or SGLang must call `.contiguous()` on the query tensor before invoking flashinfer kernels ([vLLM call of flashinfer](https://github.com/vllm-project/vllm/blob/b7df53cd42f3eab007b4f287c151960858e949df/vllm/attention/backends/flashinfer.py#L839), [SGLang call of flashinfer](https://github.com/sgl-project/sglang/blob/87a7cfa080cec3f123618c1429b5f998bf5d99cb/python/sglang/srt/layers/attention/flashinfer_backend.py#L236)). This PR fixes that: prefill/decode kernels with a paged KV cache now support non-contiguous query tensors.

## Main Changes

1. Add query-tensor strides to `BatchPrefillPagedParams` and `BatchDecodeParams`.
2. Set the stride parameters before calling those kernels.
3. Modify the JIT compilation templates to support the new kernel parameters.
4. Add tests.

The Python interfaces remain the same; the only difference is that they now accept non-contiguous query tensors.

---------

Signed-off-by: LinHeLurking
---
 flashinfer-aot/csrc_aot/batch_decode.cu         |  8 +-
 flashinfer-aot/csrc_aot/batch_prefill.cu        | 12 +--
 include/flashinfer/attention/decode.cuh         |  6 +-
 include/flashinfer/attention/decode_params.cuh  |  9 ++-
 include/flashinfer/attention/prefill.cuh        |  2 +-
 include/flashinfer/attention/prefill_params.cuh | 10 ++-
 python/flashinfer/jit/batch_decode_templ.py     |  5 +-
 python/flashinfer/jit/batch_prefill_templ.py    |  5 +-
 tests/test_non_contiguous_decode.py             | 77 ++++++++++++++++++
 tests/test_non_contiguous_prefill.py            | 79 +++++++++++++++++++
 10 files changed, 196 insertions(+), 17 deletions(-)
 create mode 100644 tests/test_non_contiguous_decode.py

diff --git a/flashinfer-aot/csrc_aot/batch_decode.cu b/flashinfer-aot/csrc_aot/batch_decode.cu
index 95945a5c..48e0e6bd 100644
--- a/flashinfer-aot/csrc_aot/batch_decode.cu
+++ b/flashinfer-aot/csrc_aot/batch_decode.cu
@@ -128,6 +128,10 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
   auto q_scalar_type = q.scalar_type();
   auto kv_scalar_type = paged_k_cache.scalar_type();
 
+  // get q_stride_n and q_stride_h
+  const auto q_stride_n = q.stride(0);
+  const auto q_stride_h = q.stride(1);
+
   // get kv_cache_strides
   const int64_t* kv_cache_strides = nullptr;
   auto k_strides = paged_k_cache.strides();
@@ -157,8 +161,8 @@ std::vector<torch::Tensor> BatchDecodeWithPagedKVCacheRun(
     ParamsT params(static_cast<DTypeQ*>(q.data_ptr()),
                    /*q_offset=*/nullptr, paged_kv, static_cast<DTypeO*>(o.data_ptr()),
                   /*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
-                  /*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap,
-                  sm_scale, rope_scale, rope_theta);
+                  /*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
+                  logits_soft_cap, sm_scale, rope_scale, rope_theta);
 
     DTypeO* tmp_v = nullptr;
     float* tmp_s = nullptr;

diff --git a/flashinfer-aot/csrc_aot/batch_prefill.cu b/flashinfer-aot/csrc_aot/batch_prefill.cu
index 448f4a9f..0289269f 100644
--- a/flashinfer-aot/csrc_aot/batch_prefill.cu
+++ b/flashinfer-aot/csrc_aot/batch_prefill.cu
@@ -237,6 +237,10 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
   auto q_scalar_type = q.scalar_type();
   auto kv_scalar_type = paged_k_cache.scalar_type();
 
+  // get q_stride_n and q_stride_h
+  const auto q_stride_n = q.stride(0);
+  const auto q_stride_h = q.stride(1);
+
   // get kv_cache_strides
   const int64_t* kv_cache_strides = nullptr;
   auto k_strides = paged_k_cache.strides();
@@ -254,8 +258,7 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
     paged_kv_t<DTypeKV, IdType> paged_kv(
         num_kv_heads, page_size, HEAD_DIM, batch_size, kv_layout,
         static_cast<DTypeKV*>(paged_k_cache.data_ptr()),
-        static_cast<DTypeKV*>(paged_v_cache.data_ptr()),
-        kv_cache_strides,
+        static_cast<DTypeKV*>(paged_v_cache.data_ptr()), kv_cache_strides,
         static_cast<IdType*>(paged_kv_indices.data_ptr()),
         static_cast<IdType*>(paged_kv_indptr.data_ptr()),
         static_cast<IdType*>(paged_kv_last_page_len.data_ptr()));
@@ -266,7 +269,6 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
                       get_variant_code(/*use_custom_mask=*/MASK_MODE == MaskMode::kCustom,
                                        /*use_sliding_window=*/true, USE_LOGITS_SOFT_CAP,
                                        /*use_alibi_slopes=*/false)>;
-
     PagedParamsT params(
         static_cast<DTypeQ*>(q.data_ptr()), paged_kv,
         maybe_custom_mask.has_value() ? static_cast<uint8_t*>(maybe_custom_mask->data_ptr())
@@ -276,8 +278,8 @@ std::vector<torch::Tensor> BatchPrefillWithPagedKVCacheRun(
                                       : nullptr,
        /*q_offset=*/nullptr, static_cast<DTypeO*>(o.data_ptr()),
        /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
-       /*alibi_slopes=*/nullptr, num_qo_heads, window_left, logits_soft_cap, sm_scale,
-       rope_scale, rope_theta);
+       /*alibi_slopes=*/nullptr, num_qo_heads, q_stride_n, q_stride_h, window_left,
+       logits_soft_cap, sm_scale, rope_scale, rope_theta);
 
     DTypeO* tmp_v = nullptr;
     float* tmp_s = nullptr;

diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh
index 4d99395e..da5084a3 100644
--- a/include/flashinfer/attention/decode.cuh
+++ b/include/flashinfer/attention/decode.cuh
@@ -439,6 +439,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
   vec_t<float, vec_size> q_vec;
   vec_t<float, vec_size> freq;
   int32_t q_offset_val = q_offset == nullptr ? (kv_len - 1) : q_offset[batch_idx];
+  const uint32_t q_stride_n = params.q_stride_n;
+  const uint32_t q_stride_h = params.q_stride_h;
   if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) {
     const float rope_rcp_scale = params.rope_rcp_scale;
     const float rope_rcp_theta = params.rope_rcp_theta;
@@ -450,10 +452,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__
     }
     // apply rotary embedding to q matrix
     q_vec = vec_apply_llama_rope<vec_size, bdx>(
-        q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim, freq, q_offset_val);
+        q + batch_idx * q_stride_n + qo_head_idx * q_stride_h, freq, q_offset_val);
   } else {
     // do not apply rotary embedding to q matrix
-    q_vec.cast_load(q + (batch_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size);
+    q_vec.cast_load(q + batch_idx * q_stride_n + qo_head_idx * q_stride_h + tx * vec_size);
   }
 #pragma unroll
   for (uint32_t i = 0; i < vec_size; ++i) {

diff --git a/include/flashinfer/attention/decode_params.cuh b/include/flashinfer/attention/decode_params.cuh
index 3dac774a..5505a33f 100644
--- a/include/flashinfer/attention/decode_params.cuh
+++ b/include/flashinfer/attention/decode_params.cuh
@@ -119,6 +119,8 @@ struct BatchDecodeParams {
   float* alibi_slopes;
   uint32_t padded_batch_size;
   uint32_t num_qo_heads;
+  IdType q_stride_n;
+  IdType q_stride_h;
   int32_t window_left;
   float logits_soft_cap;
   float sm_scale;
@@ -135,8 +137,9 @@ struct BatchDecodeParams {
   __device__ __host__ BatchDecodeParams(DTypeQ* q, IdType* q_offset,
                                         paged_kv_t<DTypeKV, IdType> paged_kv, DTypeO* o,
                                         float* lse, float* alibi_slopes, uint32_t num_qo_heads,
-                                        int32_t window_left, float logits_soft_cap, float sm_scale,
-                                        float rope_scale, float rope_theta)
+                                        IdType q_stride_n, IdType q_stride_h, int32_t window_left,
+                                        float logits_soft_cap, float sm_scale, float rope_scale,
+                                        float rope_theta)
       : q(q),
         q_offset(q_offset),
         paged_kv(paged_kv),
@@ -145,6 +148,8 @@ struct BatchDecodeParams {
         alibi_slopes(alibi_slopes),
         padded_batch_size(0),
         num_qo_heads(num_qo_heads),
+        q_stride_n(q_stride_n),
+        q_stride_h(q_stride_h),
         window_left(window_left),
         logits_soft_cap(logits_soft_cap),
         sm_scale(sm_scale),

diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh
index 11df612c..925047b7 100644
--- a/include/flashinfer/attention/prefill.cuh
+++ b/include/flashinfer/attention/prefill.cuh
@@ -1867,7 +1867,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag
   const uint32_t qo_packed_idx_base =
       (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q<NUM_WARPS_Q, NUM_WARPS_KV>()) * NUM_FRAGS_Q * 16;
-  const uint32_t q_stride_n = num_qo_heads * head_dim, q_stride_h = head_dim;
+  const uint32_t q_stride_n = params.q_stride_n, q_stride_h = params.q_stride_h;
   constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B;
   smem_t<swizzle_mode_q> qo_smem(smem);
   DTypeQ* q_ptr_base = q + get_elem_offset_impl(q_indptr[request_idx], kv_head_idx * group_size,

diff --git a/include/flashinfer/attention/prefill_params.cuh b/include/flashinfer/attention/prefill_params.cuh
index 60f101ab..fd7f0d80 100644
--- a/include/flashinfer/attention/prefill_params.cuh
+++ b/include/flashinfer/attention/prefill_params.cuh
@@ -212,6 +212,8 @@ struct BatchPrefillPagedParams {
   float* lse;
   float* alibi_slopes;
   uint32_t num_qo_heads;
+  IdType q_stride_n;
+  IdType q_stride_h;
   int32_t window_left;
   float logits_soft_cap;
   float sm_scale;
@@ -232,9 +234,9 @@ struct BatchPrefillPagedParams {
   __host__ BatchPrefillPagedParams(DTypeQ* q, paged_kv_t<DTypeKV, IdType> paged_kv,
                                    uint8_t* custom_mask, IdType* q_indptr, IdType* qk_indptr,
                                   IdType* q_offset, DTypeO* o, float* lse, float* alibi_slopes,
-                                   uint32_t num_qo_heads, int32_t window_left,
-                                   float logits_soft_cap, float sm_scale, float rope_scale,
-                                   float rope_theta)
+                                   uint32_t num_qo_heads, IdType q_stride_n, IdType q_stride_h,
+                                   int32_t window_left, float logits_soft_cap, float sm_scale,
+                                   float rope_scale, float rope_theta)
       : q(q),
         paged_kv(paged_kv),
         custom_mask(custom_mask),
@@ -245,6 +247,8 @@ struct BatchPrefillPagedParams {
         lse(lse),
         alibi_slopes(alibi_slopes),
         num_qo_heads(num_qo_heads),
+        q_stride_n(q_stride_n),
+        q_stride_h(q_stride_h),
         window_left(window_left),
         logits_soft_cap(logits_soft_cap),
         sm_scale(sm_scale),

diff --git a/python/flashinfer/jit/batch_decode_templ.py b/python/flashinfer/jit/batch_decode_templ.py
index 28b9582e..349b3e95 100644
--- a/python/flashinfer/jit/batch_decode_templ.py
+++ b/python/flashinfer/jit/batch_decode_templ.py
@@ -100,6 +100,9 @@
   void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr());
   void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr());
+
+  const auto q_stride_n = q.stride(0);
+  const auto q_stride_h = q.stride(1);
 
   const int64_t* kv_cache_strides = nullptr;
   auto k_strides = paged_k_cache.strides();
@@ -121,7 +124,7 @@
     /*q_offset=*/nullptr, paged_kv, static_cast<{{ dtype_o }}*>(o.data_ptr()),
     /*lse=*/(return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr),
     {% if use_alibi == "true" %}static_cast<float*>(alibi_slopes->data_ptr()){% else %}nullptr{% endif %},
-    num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);
+    num_qo_heads, q_stride_n, q_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);
 
   {{ dtype_o }}* tmp_v = nullptr;
   float* tmp_s = nullptr;

diff --git a/python/flashinfer/jit/batch_prefill_templ.py b/python/flashinfer/jit/batch_prefill_templ.py
index f9a2b628..f0cdf89c 100644
--- a/python/flashinfer/jit/batch_prefill_templ.py
+++ b/python/flashinfer/jit/batch_prefill_templ.py
@@ -195,6 +195,9 @@
   void* float_buffer_ptr = static_cast<void*>(float_workspace_buffer.data_ptr());
   void* int_buffer_ptr = static_cast<void*>(int_workspace_buffer.data_ptr());
+
+  const auto q_stride_n = q.stride(0);
+  const auto q_stride_h = q.stride(1);
 
   const int64_t* kv_cache_strides = nullptr;
   auto k_strides = paged_k_cache.strides();
@@ -221,7 +224,7 @@
     static_cast<{{ dtype_o }}*>(o.data_ptr()),
    /*lse=*/return_lse ? static_cast<float*>(lse.data_ptr()) : nullptr,
    {% if use_alibi == "true" %}static_cast<float*>(maybe_alibi_slopes->data_ptr()){% else %}nullptr{% endif %},
-    num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);
+    num_qo_heads, q_stride_n, q_stride_h, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta);
 
   {{ dtype_o }}* tmp_v = nullptr;
   float* tmp_s = nullptr;

diff --git a/tests/test_non_contiguous_decode.py b/tests/test_non_contiguous_decode.py
new file mode 100644
index 00000000..82fbeeb0
--- /dev/null
+++ b/tests/test_non_contiguous_decode.py
@@ -0,0 +1,77 @@
+import torch
+import pytest
+import flashinfer
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99])
+@pytest.mark.parametrize("page_size", [1, 5])
+@pytest.mark.parametrize("seq_len", [1])
+@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
+@pytest.mark.parametrize("num_qo_heads", [4, 8])
+@pytest.mark.parametrize("head_dim", [64, 128, 256])
+def test_batch_paged_decode_packed_input(
+    batch_size,
+    page_size,
+    seq_len,
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+):
+    if num_qo_heads % num_kv_heads != 0:
+        pytest.skip("num_qo_heads must be a multiple of num_kv_heads")
+    nnz = batch_size * seq_len
+    num_pages_per_req = (seq_len + page_size - 1) // page_size
+    num_pages = batch_size * num_pages_per_req
+    last_page_len = (seq_len - 1) % page_size + 1
+    k_cache = torch.randn(
+        size=(num_pages, page_size, num_kv_heads, head_dim),
+        dtype=torch.float16,
+        device="cuda:0",
+    )
+    v_cache = torch.randn_like(k_cache)
+    paged_kv_cache = (k_cache, v_cache)
+    workspace_buffer = torch.empty(
+        (256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0"
+    )
+    paged_kv_indptr = torch.tensor(
+        [i * num_pages_per_req for i in range(batch_size + 1)],
+        dtype=torch.int32,
+        device="cuda:0",
+    )
+    paged_kv_indices = torch.tensor(
+        list(range(num_pages)), dtype=torch.int32, device="cuda:0"
+    )
+    paged_kv_last_page_len = torch.tensor(
+        [last_page_len for _ in range(batch_size)], dtype=torch.int32, device="cuda:0"
+    )
+
+    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer)
+    wrapper.plan(
+        indptr=paged_kv_indptr,
+        indices=paged_kv_indices,
+        last_page_len=paged_kv_last_page_len,
+        num_qo_heads=num_qo_heads,
+        num_kv_heads=num_kv_heads,
+        head_dim=head_dim,
+        page_size=page_size,
+    )
+
+    qkv_packed = torch.randn(
+        size=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim),
+        dtype=torch.float16,
+        device="cuda:0",
+    )
+    qkv_split_idx = (
+        num_qo_heads * head_dim,
+        num_kv_heads * head_dim,
+        num_kv_heads * head_dim,
+    )
+    q, _, _ = qkv_packed.split(qkv_split_idx, dim=-1)
+    q = q.view(-1, num_qo_heads, head_dim)
+    o_packed = wrapper.run(q, paged_kv_cache)
+    o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache)
+    torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+    test_batch_paged_decode_packed_input(37, 127, 1, 4, 64, 128)

diff --git a/tests/test_non_contiguous_prefill.py b/tests/test_non_contiguous_prefill.py
index 53f8d688..2842a656 100644
--- a/tests/test_non_contiguous_prefill.py
+++ b/tests/test_non_contiguous_prefill.py
@@ -96,6 +96,85 @@ def test_batch_ragged_prefill_packed_input(
     torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)
 
 
+@pytest.mark.parametrize("batch_size", [1, 19, 99])
+@pytest.mark.parametrize("page_size", [1, 5])
+@pytest.mark.parametrize("seq_len", [1, 7, 127, 257])
+@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
+@pytest.mark.parametrize("num_qo_heads", [4, 8])
+@pytest.mark.parametrize("head_dim", [64, 128, 256]) +@pytest.mark.parametrize("causal", [True, False]) +def test_batch_paged_prefill_packed_input( + batch_size, + page_size, + seq_len, + num_kv_heads, + num_qo_heads, + head_dim, + causal, +): + if num_qo_heads % num_kv_heads != 0: + pytest.skip("num_qo_heads must be a multiple of num_kv_heads") + + nnz = batch_size * seq_len + num_pages_per_req = (seq_len + page_size - 1) // page_size + num_pages = batch_size * num_pages_per_req + last_page_len = (seq_len - 1) % page_size + 1 + k_cache = torch.randn( + size=(num_pages, page_size, num_kv_heads, head_dim), + dtype=torch.float16, + device="cuda:0", + ) + v_cache = torch.randn_like(k_cache) + paged_kv_cache = (k_cache, v_cache) + workspace_buffer = torch.empty( + (256 * 1024 * 1024,), dtype=torch.uint8, device="cuda:0" + ) + qo_indptr = torch.tensor( + [i * seq_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" + ) + paged_kv_indptr = torch.tensor( + [i * num_pages_per_req for i in range(batch_size + 1)], + dtype=torch.int32, + device="cuda:0", + ) + paged_kv_indices = torch.tensor( + list(range(num_pages)), dtype=torch.int32, device="cuda:0" + ) + paged_kv_last_page_len = torch.tensor( + [last_page_len for _ in range(batch_size)], dtype=torch.int32, device="cuda:0" + ) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer) + wrapper.plan( + qo_indptr=qo_indptr, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + page_size=page_size, + causal=causal, + ) + + qkv_packed = torch.randn( + size=(nnz, (num_qo_heads + 2 * num_kv_heads) * head_dim), + dtype=torch.float16, + device="cuda:0", + ) + qkv_split_idx = ( + num_qo_heads * head_dim, + num_kv_heads * head_dim, + num_kv_heads * head_dim, + ) + q, _, _ = qkv_packed.split(qkv_split_idx, dim=-1) + # pretend that we have already appended k/v to paged_kv table + q = q.view(-1, num_qo_heads, head_dim) + o_packed = wrapper.run(q, paged_kv_cache) + o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache) + torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": test_single_prefill_packed_input(127, 4, 4, 64, True) test_batch_ragged_prefill_packed_input(37, 127, 4, 4, 64, True) + test_batch_paged_prefill_packed_input(37, 5, 127, 4, 4, 64, True)