feat: support MLA decode (#551)
Hi, this PR implements the MLA decode algorithm. I would love to hear your
thoughts on the design and implementation.
### The mysterious Mat Absorb algorithm
The DeepSeekV2 paper gives no specific formulas for how to absorb the
parameter matrices; it only vaguely says that

> Fortunately, due to the associative law of matrix multiplication, we
> can absorb W^UK into W^UQ, and W^UV into W^O

I know there has been some discussion on this topic, but I still couldn't find
a convincing answer.
Here is my conclusion: Mat Absorb is only suitable for decode; do not use it
for prefill. This means MLA decode and prefill should have different
computation graphs and different sets of params, and Mat Absorb should merge
the param matrices offline and materialize the merged matrices. The two sets of
Mat Absorb merges are the two einsum ops in
[test_mla_decode_kernel.py](https://github.com/flashinfer-ai/flashinfer/blob/cd74b457981d165a0afc5317e1a65a495e32b3c9/tests/test_mla_decode_kernel.py#L209).
```
# Now we merge W_UQ and W_UK (absorb W_UK into W_UQ)
# q~q_lora_rank  n~num_heads  d~qk_nope_head_dim  l~kv_lora_rank
self.W_UQ_UK = torch.einsum("q n d, l n d -> q n l", W_UQ, W_UK).flatten(start_dim=1) # [1536, 65536]

# Merge W_UV and W_O (absorb W_UV into W_O)
# l~kv_lora_rank  n~num_heads  d~v_head_dim  h~hidden_size
self.W_UV_O = torch.einsum("l n d, h n d -> n l h", W_UV, W_O).flatten(start_dim=0, end_dim=1) # [65536, 5120]
```
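To see why the absorption is legal, here is a minimal numerical check of the
associativity identity behind it, using a single head and float64; the shapes
and variable names below are illustrative, not the exact DeepSeek-V2 config.
```
# Sketch: verify (c_Q @ W_UQ) @ (c_kv @ W_UK).T == (c_Q @ (W_UQ @ W_UK.T)) @ c_kv.T
import torch

q_lora_rank, kv_lora_rank, qk_nope_head_dim = 1536, 512, 128
c_Q  = torch.randn(4, q_lora_rank, dtype=torch.float64)     # down-projected query activations
c_kv = torch.randn(10, kv_lora_rank, dtype=torch.float64)   # cached compressed_kv
W_UQ = torch.randn(q_lora_rank, qk_nope_head_dim, dtype=torch.float64)
W_UK = torch.randn(kv_lora_rank, qk_nope_head_dim, dtype=torch.float64)

# vanilla: up-project both sides, then take dot products per position
scores_vanilla = (c_Q @ W_UQ) @ (c_kv @ W_UK).T             # [4, 10]

# absorbed: fold W_UK into W_UQ offline, attend against compressed_kv directly
W_UQ_UK = W_UQ @ W_UK.T                                     # [q_lora_rank, kv_lora_rank]
scores_absorbed = (c_Q @ W_UQ_UK) @ c_kv.T                  # [4, 10]

assert torch.allclose(scores_vanilla, scores_absorbed)
```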
I'm going to state my reasoning below. First, let me depict the original MLA
algorithm as a computation graph (the final o_proj is omitted). It can be
regarded as a 128-head / (128+64)-dim MHA algorithm.
<img width="891" alt="image"
src="https://github.com/user-attachments/assets/d2410816-0898-4a4e-afcf-86ad78044237">
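As a concrete illustration of this graph, a minimal torch sketch of the
vanilla decode step is below (hypothetical shapes and names; `W_UK`/`W_UV`
stand in for the per-head halves of `kv_b_proj`).
```
# Sketch of the vanilla view: compressed_kv is up-projected per head into
# k_nope and v, giving an ordinary 128-head / (128+64)-dim MHA step.
import torch

num_heads, kv_lora_rank = 128, 512
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
kv_len = 1024

q_nope = torch.randn(num_heads, qk_nope_head_dim)
q_pe   = torch.randn(num_heads, qk_rope_head_dim)
compressed_kv = torch.randn(kv_len, kv_lora_rank)
k_pe   = torch.randn(kv_len, qk_rope_head_dim)
W_UK   = torch.randn(kv_lora_rank, num_heads, qk_nope_head_dim)
W_UV   = torch.randn(kv_lora_rank, num_heads, v_head_dim)

k_nope = torch.einsum("kl,lnd->knd", compressed_kv, W_UK)    # [kv_len, heads, 128]
v      = torch.einsum("kl,lnd->knd", compressed_kv, W_UV)    # [kv_len, heads, 128]

scale  = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5
scores = (torch.einsum("nd,knd->nk", q_nope, k_nope) + q_pe @ k_pe.T) * scale
attn   = torch.softmax(scores, dim=-1)                       # [heads, kv_len]
out    = torch.einsum("nk,knd->nd", attn, v)                 # [heads, 128], then o_proj
```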
After Mat Absorb, MLA becomes a special 128-head / (512+64)-dim MQA algorithm;
note that the compressed_kv is used as both K and V directly, without any
projection. The detailed Mat Absorb algorithm can be found in
[test_mla_decode_kernel.py](https://github.com/flashinfer-ai/flashinfer/blob/cd74b457981d165a0afc5317e1a65a495e32b3c9/tests/test_mla_decode_kernel.py#L206),
in which `DeepseekV2AttentionVanilla` is the original DeepSeekV2 MLA inference
implementation, copied from Hugging Face and slightly modified; we take
`DeepseekV2AttentionVanilla` as a reference to verify the correctness of our
Mat Absorb implementation. `DeepseekV2AttentionMatAbsorbDecode` is our Mat
Absorb implementation. It has two versions of the inference function
(`run_proof_of_concept`): one is implemented purely in torch, which should make
it clear how the Mat Absorb version of MLA inference works, and the other uses
our new flashinfer MLA decode kernel, so you can also take it as a usage
example.
<img width="980" alt="image"
src="https://github.com/user-attachments/assets/46e03c8d-666c-49f9-869b-af862602050c">
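For comparison with the vanilla sketch above, here is a minimal pure-torch
sketch (hypothetical shapes, single token, single request) of the absorbed
decode step viewed as MQA; the cached compressed_kv acts as both K and V with
no per-head projection.
```
# Sketch of the absorbed view: 128-head / (512+64)-dim MQA over the latent cache.
import torch

num_heads, kv_lora_rank, qk_rope_head_dim = 128, 512, 64
kv_len = 1024

q_nope = torch.randn(num_heads, kv_lora_rank)      # c_Q projected through the merged W_UQ_UK
q_pe   = torch.randn(num_heads, qk_rope_head_dim)  # RoPE part of the query
compressed_kv = torch.randn(kv_len, kv_lora_rank)  # shared latent KV cache (a single "kv head")
k_pe   = torch.randn(kv_len, qk_rope_head_dim)     # cached RoPE keys, also shared across heads

scale  = (128 + 64) ** -0.5                        # same softmax scale as the vanilla view
scores = (q_nope @ compressed_kv.T + q_pe @ k_pe.T) * scale   # [num_heads, kv_len]
attn   = torch.softmax(scores, dim=-1)
o_latent = attn @ compressed_kv                    # [num_heads, kv_lora_rank], then W_UV_O -> hidden states
```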

Now let's do some calculations to see whether the Mat Absorb version is
performant (for convenience, we call the original MLA algorithm the Vanilla
version).
```
# We calculate the number of float ops needed by the part of MLA computation graph,
# the input tensors are c_Q and cached k_pe and compressed_kv, the output tensor is the output hidden states.
# We omitted the calculation from input hidden states to c_Q and cached k_pe and compressed_kv, 
# because it's the same for both vanilla version and mat-absorb version.
def num_float_ops_vanilla(q_len, kv_len):
    return ( q_len*1536*(128*192) + # from c_Q to q_pe and q_nope, corresponding to q_b_proj
                kv_len * 512 * (128*(128+128)) + # from compressed_kv to k_nop and value_states, corresponding to kv_b_proj
                128 * (q_len*64*kv_len + q_len*128*kv_len + q_len*kv_len*128) + # 128 heads MHA
                q_len * (128*128) * 5120 ) # from MHA output to output hidden states, corresponding to o_proj
def mem_footprint_vanilla(q_len, kv_len):
    return ( q_len*1536 + 1536*(128*192) + 
                kv_len*512 + 512*(128*(128+128)) + 
                128 * ((q_len*64 + 64*kv_len) + (q_len*128 + 128*kv_len)) + 
                q_len * (128*128) + (128*128) * 5120 ) 
def num_float_ops_mat_absorb(q_len, kv_len):
    return ( q_len*1536*(128*64) + # from c_Q to q_pe, corresponding to W_QR
                q_len*1536*(128*512) + # from c_Q to q_nope, corresponding to W_UQUK
                128 * (q_len*64*kv_len + q_len*512*kv_len + q_len*kv_len*512) + # 128 heads MQA
                q_len * (128*512) * 5120 ) # from MQA output to output hidden states, corresponding to W_UV_O
def mem_footprint_mat_absorb(q_len, kv_len):
    return ( q_len*1536 + 1536*(128*64) +
                1536*(128*512) +
                128 * (q_len*64 + q_len*512) + 1*(64*kv_len + 512*kv_len) +
                q_len * (128*512) + (128*512) * 5120 ) # from MQA output to output hidden states, corresponding to W_UV_O
kv_len = 10000
print(f"prefill: num_float_ops mat_absorb vs vanilla ratio  ~ {num_float_ops_mat_absorb(kv_len, kv_len) / num_float_ops_vanilla(kv_len, kv_len)}")
print(f"prefill: mem_footprint mat_absorb vs vanilla ratio  ~ {mem_footprint_mat_absorb(kv_len, kv_len) / mem_footprint_vanilla(kv_len, kv_len)}")
print(f"decode: num_float_ops mat_absorb vs vanilla ratio  ~ {num_float_ops_mat_absorb(1, kv_len) / num_float_ops_vanilla(1, kv_len)}")
print(f"decode: mem_footprint mat_absorb vs vanilla ratio  ~ {mem_footprint_mat_absorb(1, kv_len) / mem_footprint_vanilla(1, kv_len)}")
```
The output is:
```
prefill: num_float_ops mat_absorb vs vanilla ratio  ~ 3.3602009088734754
prefill: mem_footprint mat_absorb vs vanilla ratio  ~ 2.2874373717252205
decode: num_float_ops mat_absorb vs vanilla ratio  ~ 0.010941137164898957
decode: mem_footprint mat_absorb vs vanilla ratio  ~ 1.167867978048944
```
From the results above we can conclude that, for the decode case, the Mat
Absorb version uses only about 1% of the computation of the Vanilla version
while its memory footprint stays at the same level; for the prefill case,
however, both computation and memory footprint are much higher than the
Vanilla version. So there is no reason to use Mat Absorb for prefill, but it's
worth a try for decode.

### The kernel implementation design
The new MLA decode kernel follows the same design concept as the current
decode kernel and reuses much of the existing code base; we add some helper
functions, such as `dim3_offset`, for better code readability.
The scheduling policy is also the same as the current one: we split the task
along the kv-len dimension, and because num_kv_heads is now 1, we can no
longer split the num_kv_heads dimension across blocks. One problem is that the
128-head / (512+64)-dim Q data is too large to fit into one SM's register
file, or even its smem, so we can't use a single SM/block to process one piece
of Q data. We have to tile the num_qo_heads dimension into gridDim.y, which
causes kv-cache data to be moved from gmem to smem multiple times, though this
is inevitable.
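To make the tiling concrete, here is a back-of-the-envelope sketch of how the
launch configuration works out for DeepSeek-V2-style dimensions (HEAD_DIM_CKV
= 512, HEAD_DIM_KPE = 64, fp16 KV); the constants mirror
`BatchDecodeWithPagedKVCacheDispatchedMLA` in the diff below, and the point is
that gridDim.y > 1 means every kv chunk is streamed into smem once per head
tile.
```
# Back-of-the-envelope launch configuration for the MLA decode kernel (fp16 KV assumed).
import math

HEAD_DIM_CKV, HEAD_DIM_KPE = 512, 64
dtype_kv_bytes = 2                                            # fp16
vec_size_ckv = max(16 // dtype_kv_bytes, HEAD_DIM_CKV // 32)  # 16 elements per thread
bdx = HEAD_DIM_CKV // vec_size_ckv                            # 32 threads cover one ckv head dim
bdy, tile_size_qo_heads = 8, 2
qo_heads_per_block = bdy * tile_size_qo_heads                 # 16 query heads per block
num_threads = max(128, bdx * bdy)
bdz = num_threads // (bdx * bdy)                              # 1 for this config

num_qo_heads = 128
gdy = math.ceil(num_qo_heads / qo_heads_per_block)            # 8 blocks along gridDim.y
# Each of the gdy head tiles re-reads the same kv chunk from gmem into smem,
# so kv traffic grows by a factor of gdy compared to a single-block-per-chunk scheme.
print(f"bdx={bdx}, bdy={bdy}, bdz={bdz}, heads/block={qo_heads_per_block}, gridDim.y={gdy}")
```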

### Further improvement

- A tensor-core implementation: the current MLA models (DeepSeek-V2-Lite,
DeepSeek-V2, MiniCPM3) all have a num_qo_heads large enough to feed mma
fragments. In my opinion, though, this may bring only limited performance
improvement because, per the analysis above, the bottleneck is IO bandwidth
rather than compute intensity.
- Load more Q head data per thread and per block. The more Q head data we
load, the fewer blocks are needed and the less kv data movement from gmem to
smem is required. We can hold more `q_nope_vec` per thread, and we can also
use smem to store more `q_nope_vec`. I would love to hear input from others.


BTW, the new variable and function naming may not follow the current
convention; I'm willing to change it according to your advice.

---------

Co-authored-by: tsu-bin <tsubin@gmail.com>
tsu-bin and tsu-bin authored Nov 2, 2024
1 parent 06a922f commit 5d454ed
Showing 18 changed files with 1,887 additions and 102 deletions.
9 changes: 9 additions & 0 deletions CMakeLists.txt
@@ -319,6 +319,15 @@ if (FLASHINFER_DECODE)
target_link_libraries(bench_batch_decode PRIVATE nvbench::main decode_kernels prefill_kernels)
target_compile_options(bench_batch_decode PRIVATE -Wno-switch-bool)

message(STATUS "Compile batch mla decode kernel benchmarks.")
file(GLOB_RECURSE BENCH_DECODE_MLA_SRCS ${PROJECT_SOURCE_DIR}/src/bench_batch_decode_mla.cu)
add_executable(bench_batch_decode_mla ${BENCH_DECODE_MLA_SRCS})
target_include_directories(bench_batch_decode_mla PRIVATE ${FLASHINFER_INCLUDE_DIR})
target_include_directories(bench_batch_decode_mla PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench)
add_dependencies(bench_batch_decode_mla dispatch_inc)
target_link_libraries(bench_batch_decode_mla PRIVATE nvbench::main decode_kernels)
target_compile_options(bench_batch_decode_mla PRIVATE -Wno-switch-bool)

message(STATUS "Compile batch decode kernel tests.")
file(GLOB_RECURSE TEST_DECODE_SRCS ${PROJECT_SOURCE_DIR}/src/test_batch_decode.cu)
add_executable(test_batch_decode ${TEST_DECODE_SRCS})
2 changes: 1 addition & 1 deletion cmake/config.cmake
@@ -24,7 +24,7 @@ set(FLASHINFER_FASTDEQUANT_TEST ON)
set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256 512)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true")
332 changes: 332 additions & 0 deletions include/flashinfer/attention/decode.cuh
@@ -777,6 +777,338 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::Par
return cudaSuccess;
}


template <uint32_t vec_size_ckv, uint32_t vec_size_kpe, uint32_t bdx, uint32_t tile_size,
typename AttentionVariant, typename T>
__device__ __forceinline__ void compute_qk_and_update_local_stat_mla(
const typename AttentionVariant::ParamsT& params,
AttentionVariant variant, const uint32_t batch_idx,
const T* ckv_smem,
const vec_t<float, vec_size_ckv>& q_nope_vec,
const T* kpe_smem,
const vec_t<float, vec_size_kpe>& q_pe_vec,
const vec_t<float, vec_size_kpe>& freq,
uint32_t kv_idx_base, uint32_t iter_base, uint32_t iter_bound,
state_t<vec_size_ckv>& st
) {
uint32_t tx = threadIdx.x, tz = threadIdx.z;
constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv;
constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe;
float s[tile_size];
float m_prev = st.m;
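// each iteration scores one key in the tile: dot(q_nope, ckv) + dot(q_pe, RoPE(kpe)),
// reduced across the bdx lanes that partition the head dim, then updates the running max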
#pragma unroll
for (uint32_t j = 0; j < tile_size; ++j) {
vec_t<float, vec_size_ckv> ckv_vec;
ckv_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv);

vec_t<float, vec_size_kpe> kpe_vec;
kpe_vec = vec_apply_llama_rope_interleave<vec_size_kpe, bdx>(kpe_smem + j * head_dim_kpe, freq,
kv_idx_base + tz * tile_size + j);
s[j] = 0.f;
#pragma unroll
for (uint32_t i = 0; i < vec_size_ckv; ++i) {
s[j] += q_nope_vec[i] * ckv_vec[i];
}
#pragma unroll
for (uint32_t i = 0; i < vec_size_kpe; ++i) {
s[j] += q_pe_vec[i] * kpe_vec[i];
}
s[j] *= params.sm_scale;
#pragma unroll
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
s[j] += math::shfl_xor_sync(s[j], offset);
}
s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -math::inf;
st.m = max(st.m, s[j]);
}

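// online-softmax update: rescale the running denominator and output accumulator
// by exp2(m_prev - m_new) before accumulating the freshly scored tile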
float o_scale = math::ptx_exp2(m_prev - st.m);
st.d *= o_scale;
#pragma unroll
for (uint32_t j = 0; j < tile_size; ++j) {
s[j] = math::ptx_exp2(s[j] - st.m);
st.d += s[j];
}
#pragma unroll
for (uint32_t i = 0; i < vec_size_ckv; ++i) {
st.o[i] = st.o[i] * o_scale;
}

#pragma unroll
for (uint32_t j = 0; j < tile_size; ++j) {
vec_t<float, vec_size_ckv> v_vec;
v_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv);
#pragma unroll
for (uint32_t i = 0; i < vec_size_ckv; ++i) {
st.o[i] = st.o[i] + s[j] * v_vec[i];
}
}
}


template <uint32_t num_stages_smem, uint32_t vec_size_ckv, uint32_t vec_size_kpe,
uint32_t bdx, uint32_t bdy, uint32_t bdz, uint32_t tile_size_qo_heads,
typename AttentionVariant>
__global__ void BatchDecodeWithPagedKVCacheKernelMLA(typename AttentionVariant::ParamsT params) {

auto block = cg::this_thread_block();
using DTypeQ = typename AttentionVariant::DTypeQ;
using DTypeKV = typename AttentionVariant::DTypeKV;
using DTypeO = typename AttentionVariant::DTypeO;
using IdType = typename AttentionVariant::IdType;
const DTypeQ* q_nope = params.q_nope;
const DTypeQ* q_pe = params.q_pe;
DTypeO* o = params.o;
float* lse = params.lse;
const auto& paged_kv = params.paged_kv;
const IdType* q_offset = params.q_offset;
const bool* block_valid_mask = params.block_valid_mask;
const uint32_t num_qo_heads = params.num_qo_heads;
const float rope_rcp_scale = params.rope_rcp_scale;
const float rope_rcp_theta = params.rope_rcp_theta;
const bool partition_kv = params.partition_kv;

constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv;
constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe;
const uint32_t batch_idx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
const uint32_t t_offset = dim3_offset(bdy, bdx, tz, ty, tx);

// NOTE(Zihao): when CUDAGraph is enabled, we will launch more blocks than
// the actual batch size, so we need to check if the current batch is valid
if (block_valid_mask && !block_valid_mask[batch_idx]) return;
const uint32_t mapped_batch_idx = params.request_indices[batch_idx];

const uint32_t orig_seq_len = paged_kv.get_length(mapped_batch_idx);
int32_t q_offset_val = q_offset == nullptr ? (orig_seq_len - 1) : q_offset[mapped_batch_idx];

const uint32_t kv_chunk_idx_in_orig_mapped_batch = params.kv_tile_indices[batch_idx];
const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr);
const uint32_t cur_chunk_start = partition_kv ? kv_chunk_idx_in_orig_mapped_batch * kv_chunk_size : 0;
const uint32_t cur_chunk_end =
partition_kv ? min((kv_chunk_idx_in_orig_mapped_batch + 1) * kv_chunk_size, orig_seq_len) : orig_seq_len;
const uint32_t cur_chunk_len = cur_chunk_end - cur_chunk_start;

uint32_t packed_page_iter_base = paged_kv.indptr[mapped_batch_idx] * paged_kv.page_size + cur_chunk_start;
const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

constexpr uint32_t kv_iter_len = bdy * bdz;
constexpr uint32_t compute_qk_tile = bdy;

extern __attribute__((shared)) uint8_t smem[];
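// dynamic smem partition: [staged ckv tiles | staged kpe tiles | ckv/kpe page-offset buffers];
// the offset region is reused as the m/d buffer when partial states are merged across tz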
DTypeKV* ckv_smem = (DTypeKV*)smem;
DTypeKV* kpe_smem = (DTypeKV*)((uint8_t*)ckv_smem + num_stages_smem * kv_iter_len * head_dim_ckv * sizeof(DTypeKV));
size_t* ckv_offset_smem = (size_t*)((uint8_t*)kpe_smem + num_stages_smem * kv_iter_len * head_dim_kpe * sizeof(DTypeKV));
size_t* kpe_offset_smem = (size_t*)((uint8_t*)ckv_offset_smem + bdx*bdy*bdz * sizeof(size_t) );
float* smem_md = (float*)ckv_offset_smem;

AttentionVariant variant(params, batch_idx, smem);

vec_t<float, vec_size_ckv> q_nope_vec[tile_size_qo_heads];
vec_t<float, vec_size_kpe> q_pe_vec[tile_size_qo_heads];
state_t<vec_size_ckv> st[tile_size_qo_heads];
uint32_t qo_head_idx[tile_size_qo_heads];

vec_t<float, vec_size_kpe> freq;
#pragma unroll
for (uint32_t i = 0; i < vec_size_kpe; ++i) {
freq[i] = rope_rcp_scale *
__powf(rope_rcp_theta,
float(2 * ((tx * vec_size_kpe + i) / 2)) / float(head_dim_kpe));
}
// load q_nope and q_pe tile
#pragma unroll
for (int i = 0; i < tile_size_qo_heads; ++i) {
qo_head_idx[i] = dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i);
if (qo_head_idx[i] < num_qo_heads) {
q_nope_vec[i].cast_load(q_nope + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + tx * vec_size_ckv);
q_pe_vec[i] = vec_apply_llama_rope_interleave<vec_size_kpe, bdx>(
q_pe + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_kpe, freq, q_offset_val);
}
}

// init paged-cache read offset to be used
uint32_t q, r;
paged_kv.page_size.divmod(packed_page_iter_base + t_offset, q, r);
ckv_offset_smem[t_offset] =
paged_kv.protective_get_offset_ckv(q, r, /*feat_idx*/0, last_indptr);
kpe_offset_smem[t_offset] =
paged_kv.protective_get_offset_kpe(q, r, /*feat_idx*/0, last_indptr);
block.sync();

uint32_t stage_idx = 0;
constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size_ckv * 8;
constexpr uint32_t tx_fold = vec_size_ckv / vec_size_kpe;
static_assert(num_stages_smem <= bdx);
size_t offset_bytes;
bool is_valid_range;
#pragma unroll
for (uint32_t iter = 0; iter < num_stages_smem; ++iter) {
is_valid_range = ( iter * kv_iter_len + dim2_offset(bdy, tz, ty) ) < cur_chunk_len;

offset_bytes =
ckv_offset_smem[dim3_offset(bdz, bdy, iter, tz, ty)] + tx * vec_size_ckv;
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
ckv_smem + ( stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty) ) * head_dim_ckv +
tx * vec_size_ckv,
paged_kv.ckv_data + offset_bytes,
is_valid_range);

offset_bytes =
kpe_offset_smem[dim3_offset(bdz, bdy, iter, tz, ty)] + tx / tx_fold * vec_size_ckv;
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
kpe_smem + ( stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty) ) * head_dim_kpe +
tx / tx_fold * vec_size_ckv,
paged_kv.kpe_data + offset_bytes,
is_valid_range);

cp_async::commit_group();
stage_idx = (stage_idx + 1) % num_stages_smem;
}

#pragma unroll
for (uint32_t iter = 0; iter < ceil_div(cur_chunk_len, kv_iter_len); ++iter) {
cp_async::wait_group<1 * num_stages_smem - 1>();
block.sync();
const int32_t kv_idx_base = (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
cur_chunk_start + iter * kv_iter_len;
#pragma unroll
for (int i = 0; i < tile_size_qo_heads; ++i) {
compute_qk_and_update_local_stat_mla<vec_size_ckv, vec_size_kpe, bdx, compute_qk_tile, AttentionVariant>(
params, variant, mapped_batch_idx,
ckv_smem + (stage_idx * kv_iter_len + tz * compute_qk_tile )*head_dim_ckv,
q_nope_vec[i],
kpe_smem + (stage_idx * kv_iter_len + tz * compute_qk_tile )*head_dim_kpe,
q_pe_vec[i],
freq,
kv_idx_base,
/*iter_base*/iter * kv_iter_len, /*iter_bound*/cur_chunk_len,
st[i]);
}

if ((iter + num_stages_smem) % bdx == 0) {
uint32_t q, r;
paged_kv.page_size.divmod(packed_page_iter_base + (iter + num_stages_smem) * kv_iter_len + t_offset,
q, r);
ckv_offset_smem[t_offset] =
paged_kv.protective_get_offset_ckv(q, r, /*feat_idx*/0, last_indptr);
kpe_offset_smem[t_offset] =
paged_kv.protective_get_offset_kpe(q, r, /*feat_idx*/0, last_indptr);
}
block.sync();

is_valid_range = ( (iter + num_stages_smem) * kv_iter_len + dim2_offset(bdy, tz, ty) ) < cur_chunk_len;
offset_bytes =
ckv_offset_smem[dim3_offset(bdz, bdy, (iter + num_stages_smem) % bdx, tz, ty)] +
tx * vec_size_ckv;
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
ckv_smem + ( stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty) ) * head_dim_ckv +
tx * vec_size_ckv,
paged_kv.ckv_data + offset_bytes,
is_valid_range);

offset_bytes =
kpe_offset_smem[dim3_offset(bdz, bdy, (iter + num_stages_smem) % bdx, tz, ty)] +
tx / tx_fold * vec_size_ckv;
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
kpe_smem + ( stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty) ) * head_dim_kpe +
tx / tx_fold * vec_size_ckv,
paged_kv.kpe_data + offset_bytes,
is_valid_range);
cp_async::commit_group();

stage_idx = (stage_idx + 1) % num_stages_smem;
}
cp_async::wait_group<0>();
block.sync();

if (bdz != 1) {
#pragma unroll
for (int i = 0; i < tile_size_qo_heads; ++i) {
if (qo_head_idx[i] < num_qo_heads)
sync_state<vec_size_ckv, bdx, bdy, bdz>(variant, st[i], (float*)smem, smem_md);
}
}

if (tz == 0) {
#pragma unroll
for (int i = 0; i < tile_size_qo_heads; ++i) {
if (qo_head_idx[i] < num_qo_heads) {
st[i].normalize();
st[i].o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + tx * vec_size_ckv);

if (lse != nullptr) {
lse[batch_idx * num_qo_heads + qo_head_idx[i]] = st[i].get_lse();
}
}
}
}
}


template <uint32_t HEAD_DIM_CKV, uint32_t HEAD_DIM_KPE, typename AttentionVariant>
cudaError_t BatchDecodeWithPagedKVCacheDispatchedMLA(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream) {
using DTypeQ = typename AttentionVariant::DTypeQ;
using DTypeKV = typename AttentionVariant::DTypeKV;
using DTypeO = typename AttentionVariant::DTypeO;
using IdType = typename AttentionVariant::IdType;
const uint32_t num_qo_heads = params.num_qo_heads;
const uint32_t padded_batch_size = params.padded_batch_size;

constexpr uint32_t vec_size_ckv = std::max(16UL / sizeof(DTypeKV), HEAD_DIM_CKV / 32UL);
constexpr uint32_t bdx = HEAD_DIM_CKV / vec_size_ckv;
constexpr uint32_t vec_size_kpe = HEAD_DIM_KPE / bdx;

constexpr uint32_t bdy = 8;
constexpr uint32_t tile_size_qo_heads = 2;
constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
const uint32_t gdy = ceil_div(num_qo_heads, qo_heads_per_block);

auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, {
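// smem holds NUM_STAGES_SMEM staged (ckv + kpe) tiles for bdy*bdz keys, plus the
// larger of the page-offset buffers and the cross-tz m/d merge buffer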
const uint32_t smem_size =
NUM_STAGES_SMEM * bdy * bdz * (HEAD_DIM_CKV + HEAD_DIM_KPE) * sizeof(DTypeKV) +
std::max(num_threads * sizeof(size_t) * 2,
2 * bdy * bdz * sizeof(float));

auto kernel =
BatchDecodeWithPagedKVCacheKernelMLA<NUM_STAGES_SMEM, vec_size_ckv, vec_size_kpe,
bdx, bdy, bdz, tile_size_qo_heads, AttentionVariant>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

if (tmp_v == nullptr) {
// do not use partition-kv kernel
dim3 nblks(padded_batch_size, gdy);
dim3 nthrs(bdx, bdy, bdz);
params.partition_kv = false;
void* args[] = {(void*)&params};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
params.partition_kv = true;
auto o = params.o;
auto lse = params.lse;
params.o = tmp_v;
params.lse = tmp_s;
void* args[] = {(void*)&params};
dim3 nblks(padded_batch_size, gdy);
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM_CKV, stream));
}
});
return cudaSuccess;
}

} // namespace flashinfer

#endif // FLASHINFER_DECODE_CUH_