feat: support MLA decode #551

Merged 17 commits on Nov 2, 2024. Showing changes from 16 commits.
9 changes: 9 additions & 0 deletions CMakeLists.txt
@@ -319,6 +319,15 @@ if (FLASHINFER_DECODE)
target_link_libraries(bench_batch_decode PRIVATE nvbench::main decode_kernels prefill_kernels)
target_compile_options(bench_batch_decode PRIVATE -Wno-switch-bool)

message(STATUS "Compile batch MLA decode kernel benchmarks.")
file(GLOB_RECURSE BENCH_DECODE_MLA_SRCS ${PROJECT_SOURCE_DIR}/src/bench_batch_decode_mla.cu)
add_executable(bench_batch_decode_mla ${BENCH_DECODE_MLA_SRCS})
target_include_directories(bench_batch_decode_mla PRIVATE ${FLASHINFER_INCLUDE_DIR})
target_include_directories(bench_batch_decode_mla PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench)
add_dependencies(bench_batch_decode_mla dispatch_inc)
target_link_libraries(bench_batch_decode_mla PRIVATE nvbench::main decode_kernels)
target_compile_options(bench_batch_decode_mla PRIVATE -Wno-switch-bool)
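# NOTE: the new target mirrors bench_batch_decode above (same include paths,
# dispatch_inc dependency, and -Wno-switch-bool workaround), so it should
# build as a bench_batch_decode_mla executable next to the existing benchmarks.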

message(STATUS "Compile batch decode kernel tests.")
file(GLOB_RECURSE TEST_DECODE_SRCS ${PROJECT_SOURCE_DIR}/src/test_batch_decode.cu)
add_executable(test_batch_decode ${TEST_DECODE_SRCS})
2 changes: 1 addition & 1 deletion cmake/config.cmake
@@ -24,7 +24,7 @@ set(FLASHINFER_FASTDEQUANT_TEST ON)
set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
# size of the generated library
-set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
+set(FLASHINFER_GEN_HEAD_DIMS 64 128 256 512)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true")
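For context, 512 is the head dimension used by MLA's compressed-KV (ckv) path; the decoupled RoPE key (kpe) path uses a smaller dimension the list already covers. A minimal sketch of the pair of template dimensions the new kernels take (DeepSeek-V2-style values; an assumption, not something this hunk sets):

// Sketch only: assumed MLA head dimensions (DeepSeek-V2-style values).
constexpr uint32_t kHeadDimCkv = 512;  // compressed (latent) KV width
constexpr uint32_t kHeadDimKpe = 64;   // decoupled RoPE key width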
332 changes: 332 additions & 0 deletions include/flashinfer/attention/decode.cuh
@@ -777,6 +777,338 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(typename AttentionVariant::Par
return cudaSuccess;
}


template <uint32_t vec_size_ckv, uint32_t vec_size_kpe, uint32_t bdx, uint32_t tile_size,
typename AttentionVariant, typename T>
__device__ __forceinline__ void compute_qk_and_update_local_stat_mla(
const typename AttentionVariant::ParamsT& params,
AttentionVariant variant, const uint32_t batch_idx,
const T* ckv_smem,
const vec_t<float, vec_size_ckv>& q_nope_vec,
const T* kpe_smem,
const vec_t<float, vec_size_kpe>& q_pe_vec,
const vec_t<float, vec_size_kpe>& freq,
uint32_t kv_idx_base, uint32_t iter_base, uint32_t iter_bound,
state_t<vec_size_ckv>& st
) {
uint32_t tx = threadIdx.x, tz = threadIdx.z;
constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv;
constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe;
float s[tile_size];
float m_prev = st.m;
#pragma unroll
for (uint32_t j = 0; j < tile_size; ++j) {
vec_t<float, vec_size_ckv> ckv_vec;
ckv_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv);

vec_t<float, vec_size_kpe> kpe_vec;
kpe_vec = vec_apply_llama_rope_interleave<vec_size_kpe, bdx>(kpe_smem + j * head_dim_kpe, freq,
kv_idx_base + tz * tile_size + j);
s[j] = 0.f;
#pragma unroll
for (uint32_t i = 0; i < vec_size_ckv; ++i) {
s[j] += q_nope_vec[i] * ckv_vec[i];
}
#pragma unroll
for (uint32_t i = 0; i < vec_size_kpe; ++i) {
s[j] += q_pe_vec[i] * kpe_vec[i];
}
s[j] *= params.sm_scale;
#pragma unroll
for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
s[j] += math::shfl_xor_sync(s[j], offset);
}
s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -math::inf;
st.m = max(st.m, s[j]);
}

float o_scale = math::ptx_exp2(m_prev - st.m);
st.d *= o_scale;
#pragma unroll
for (uint32_t j = 0; j < tile_size; ++j) {
s[j] = math::ptx_exp2(s[j] - st.m);
st.d += s[j];
}
#pragma unroll
for (uint32_t i = 0; i < vec_size_ckv; ++i) {
st.o[i] = st.o[i] * o_scale;
}

#pragma unroll
for (uint32_t j = 0; j < tile_size; ++j) {
vec_t<float, vec_size_ckv> v_vec;
v_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv);
#pragma unroll
for (uint32_t i = 0; i < vec_size_ckv; ++i) {
st.o[i] = st.o[i] + s[j] * v_vec[i];
}
}
}
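// Reference note: the helper above performs one tile of the standard
// online-softmax recurrence. The scalar sketch below is illustrative only
// (not part of the kernel, which uses vectorized math::ptx_exp2 arithmetic):
__host__ __device__ __forceinline__ void online_softmax_step_sketch(
float s, float v, float& m, float& d, float& o) {
// m/d/o: running max, normalizer, and weighted-value accumulator.
float m_prev = m;
m = fmaxf(m, s);
float scale = exp2f(m_prev - m);  // rescale the old accumulators
float p = exp2f(s - m);           // probability weight, base-2 domain
d = d * scale + p;
o = o * scale + p * v;
}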


template <uint32_t num_stages_smem, uint32_t vec_size_ckv, uint32_t vec_size_kpe,
uint32_t bdx, uint32_t bdy, uint32_t bdz, uint32_t tile_size_qo_heads,
typename AttentionVariant>
__global__ void BatchDecodeWithPagedKVCacheKernelMLA(typename AttentionVariant::ParamsT params) {

auto block = cg::this_thread_block();
using DTypeQ = typename AttentionVariant::DTypeQ;
using DTypeKV = typename AttentionVariant::DTypeKV;
using DTypeO = typename AttentionVariant::DTypeO;
using IdType = typename AttentionVariant::IdType;
const DTypeQ* q_nope = params.q_nope;
const DTypeQ* q_pe = params.q_pe;
DTypeO* o = params.o;
float* lse = params.lse;
const auto& paged_kv = params.paged_kv;
const IdType* q_offset = params.q_offset;
const bool* block_valid_mask = params.block_valid_mask;
const uint32_t num_qo_heads = params.num_qo_heads;
const float rope_rcp_scale = params.rope_rcp_scale;
const float rope_rcp_theta = params.rope_rcp_theta;
const bool partition_kv = params.partition_kv;

constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv;
constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe;
const uint32_t batch_idx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
const uint32_t t_offset = dim3_offset(bdy, bdx, tz, ty, tx);

// NOTE(Zihao): when CUDAGraph is enabled, we will launch more blocks than
// the actual batch size, so we need to check if the current batch is valid
if (block_valid_mask && !block_valid_mask[batch_idx]) return;
const uint32_t mapped_batch_idx = params.request_indices[batch_idx];

const uint32_t orig_seq_len = paged_kv.get_length(mapped_batch_idx);
int32_t q_offset_val = q_offset == nullptr ? (orig_seq_len - 1) : q_offset[mapped_batch_idx];

const uint32_t kv_chunk_idx_in_orig_mapped_batch = params.kv_tile_indices[batch_idx];
const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr);
const uint32_t cur_chunk_start = partition_kv ? kv_chunk_idx_in_orig_mapped_batch * kv_chunk_size : 0;
const uint32_t cur_chunk_end =
partition_kv ? min((kv_chunk_idx_in_orig_mapped_batch + 1) * kv_chunk_size, orig_seq_len) : orig_seq_len;
const uint32_t cur_chunk_len = cur_chunk_end - cur_chunk_start;

uint32_t packed_page_iter_base = paged_kv.indptr[mapped_batch_idx] * paged_kv.page_size + cur_chunk_start;
const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

constexpr uint32_t kv_iter_len = bdy * bdz;
constexpr uint32_t compute_qk_tile = bdy;

extern __shared__ uint8_t smem[];
DTypeKV* ckv_smem = (DTypeKV*)smem;
DTypeKV* kpe_smem = (DTypeKV*)((uint8_t*)ckv_smem + num_stages_smem * kv_iter_len * head_dim_ckv * sizeof(DTypeKV));
size_t* ckv_offset_smem = (size_t*)((uint8_t*)kpe_smem + num_stages_smem * kv_iter_len * head_dim_kpe * sizeof(DTypeKV));
size_t* kpe_offset_smem = (size_t*)((uint8_t*)ckv_offset_smem + bdx * bdy * bdz * sizeof(size_t));
float* smem_md = (float*)ckv_offset_smem;

AttentionVariant variant(params, batch_idx, smem);

vec_t<float, vec_size_ckv> q_nope_vec[tile_size_qo_heads];
vec_t<float, vec_size_kpe> q_pe_vec[tile_size_qo_heads];
state_t<vec_size_ckv> st[tile_size_qo_heads];
uint32_t qo_head_idx[tile_size_qo_heads];

vec_t<float, vec_size_kpe> freq;
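// NOTE: freq[i] below is the interleaved-RoPE inverse frequency for lane
// j = tx * vec_size_kpe + i, i.e. rope_rcp_scale * rope_rcp_theta^(2*floor(j/2)/head_dim_kpe),
// matching the adjacent-element pairing of vec_apply_llama_rope_interleave.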
#pragma unroll
for (uint32_t i = 0; i < vec_size_kpe; ++i) {
freq[i] = rope_rcp_scale *
__powf(rope_rcp_theta,
float(2 * ((tx * vec_size_kpe + i) / 2)) / float(head_dim_kpe));
}
// load q_nope and q_pe tile
#pragma unroll
for (int i = 0; i < tile_size_qo_heads; ++i) {
qo_head_idx[i] = dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i);
if (qo_head_idx[i] < num_qo_heads) {
q_nope_vec[i].cast_load(q_nope + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + tx * vec_size_ckv);
q_pe_vec[i] = vec_apply_llama_rope_interleave<vec_size_kpe, bdx>(
q_pe + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_kpe, freq, q_offset_val);
}
}

// initialize the paged-cache read offsets consumed by the prefetch below
uint32_t q, r;
paged_kv.page_size.divmod(packed_page_iter_base + t_offset, q, r);
ckv_offset_smem[t_offset] =
paged_kv.protective_get_offset_ckv(q, r, /*feat_idx*/0, last_indptr);
kpe_offset_smem[t_offset] =
paged_kv.protective_get_offset_kpe(q, r, /*feat_idx*/0, last_indptr);
block.sync();
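// NOTE: the divmod above resolves one packed KV position per thread into a
// (page, in-page entry) coordinate; caching the resulting offsets in shared
// memory amortizes the integer division, and the table is refreshed only
// once every bdx iterations inside the main loop below.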

uint32_t stage_idx = 0;
constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size_ckv * 8;
constexpr uint32_t tx_fold = vec_size_ckv / vec_size_kpe;
static_assert(num_stages_smem <= bdx);
size_t offset_bytes;
bool is_valid_range;
#pragma unroll
for (uint32_t iter = 0; iter < num_stages_smem; ++iter) {
is_valid_range = (iter * kv_iter_len + dim2_offset(bdy, tz, ty)) < cur_chunk_len;

offset_bytes =
ckv_offset_smem[dim3_offset(bdz, bdy, iter, tz, ty)] + tx * vec_size_ckv;
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
ckv_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_ckv +
tx * vec_size_ckv,
paged_kv.ckv_data + offset_bytes,
is_valid_range);

offset_bytes =
kpe_offset_smem[dim3_offset(bdz, bdy, iter, tz, ty)] + tx / tx_fold * vec_size_ckv;
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
kpe_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_kpe +
tx / tx_fold * vec_size_ckv,
paged_kv.kpe_data + offset_bytes,
is_valid_range);

cp_async::commit_group();
stage_idx = (stage_idx + 1) % num_stages_smem;
}
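// Prologue done: num_stages_smem tiles are now in flight via cp_async, so
// the main loop can overlap the load of stage k + num_stages_smem with the
// QK/softmax work on stage k (software pipelining).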

#pragma unroll
for (uint32_t iter = 0; iter < ceil_div(cur_chunk_len, kv_iter_len); ++iter) {
cp_async::wait_group<1 * num_stages_smem - 1>();
block.sync();
const int32_t kv_idx_base = (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
cur_chunk_start + iter * kv_iter_len;
#pragma unroll
for (int i = 0; i < tile_size_qo_heads; ++i) {
compute_qk_and_update_local_stat_mla<vec_size_ckv, vec_size_kpe, bdx, compute_qk_tile, AttentionVariant>(
params, variant, mapped_batch_idx,
ckv_smem + (stage_idx * kv_iter_len + tz * compute_qk_tile) * head_dim_ckv,
q_nope_vec[i],
kpe_smem + (stage_idx * kv_iter_len + tz * compute_qk_tile) * head_dim_kpe,
q_pe_vec[i],
freq,
kv_idx_base,
/*iter_base*/iter * kv_iter_len, /*iter_bound*/cur_chunk_len,
st[i]);
}

if ((iter + num_stages_smem) % bdx == 0) {
uint32_t q, r;
paged_kv.page_size.divmod(packed_page_iter_base + (iter + num_stages_smem) * kv_iter_len + t_offset,
q, r);
ckv_offset_smem[t_offset] =
paged_kv.protective_get_offset_ckv(q, r, /*feat_idx*/0, last_indptr);
kpe_offset_smem[t_offset] =
paged_kv.protective_get_offset_kpe(q, r, /*feat_idx*/0, last_indptr);
}
block.sync();

is_valid_range = ((iter + num_stages_smem) * kv_iter_len + dim2_offset(bdy, tz, ty)) < cur_chunk_len;
offset_bytes =
ckv_offset_smem[dim3_offset(bdz, bdy, (iter + num_stages_smem) % bdx, tz, ty)] +
tx * vec_size_ckv;
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
ckv_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_ckv +
tx * vec_size_ckv,
paged_kv.ckv_data + offset_bytes,
is_valid_range);

offset_bytes =
kpe_offset_smem[dim3_offset(bdz, bdy, (iter + num_stages_smem) % bdx, tz, ty)] +
tx / tx_fold * vec_size_ckv;
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
kpe_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_kpe +
tx / tx_fold * vec_size_ckv,
paged_kv.kpe_data + offset_bytes,
is_valid_range);
cp_async::commit_group();

stage_idx = (stage_idx + 1) % num_stages_smem;
}
cp_async::wait_group<0>();
block.sync();

if (bdz != 1) {
#pragma unroll
for (int i = 0; i < tile_size_qo_heads; ++i) {
if (qo_head_idx[i] < num_qo_heads)
sync_state<vec_size_ckv, bdx, bdy, bdz>(variant, st[i], (float*)smem, smem_md);
}
}

if (tz == 0) {
#pragma unroll
for (int i = 0; i < tile_size_qo_heads; ++i) {
if (qo_head_idx[i] < num_qo_heads) {
st[i].normalize();
st[i].o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + tx * vec_size_ckv);

if (lse != nullptr) {
lse[batch_idx * num_qo_heads + qo_head_idx[i]] = st[i].get_lse();
}
}
}
}
}


template <uint32_t HEAD_DIM_CKV, uint32_t HEAD_DIM_KPE, typename AttentionVariant>
cudaError_t BatchDecodeWithPagedKVCacheDispatchedMLA(typename AttentionVariant::ParamsT params,
typename AttentionVariant::DTypeO* tmp_v,
float* tmp_s, cudaStream_t stream) {
using DTypeQ = typename AttentionVariant::DTypeQ;
using DTypeKV = typename AttentionVariant::DTypeKV;
using DTypeO = typename AttentionVariant::DTypeO;
using IdType = typename AttentionVariant::IdType;
const uint32_t num_qo_heads = params.num_qo_heads;
const uint32_t padded_batch_size = params.padded_batch_size;

constexpr uint32_t vec_size_ckv = std::max(16UL / sizeof(DTypeKV), HEAD_DIM_CKV / 32UL);
constexpr uint32_t bdx = HEAD_DIM_CKV / vec_size_ckv;
constexpr uint32_t vec_size_kpe = HEAD_DIM_KPE / bdx;

constexpr uint32_t bdy = 8;
constexpr uint32_t tile_size_qo_heads = 2;
constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
const uint32_t gdy = ceil_div(num_qo_heads, qo_heads_per_block);

auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, {
const uint32_t smem_size =
NUM_STAGES_SMEM * bdy * bdz * (HEAD_DIM_CKV + HEAD_DIM_KPE) * sizeof(DTypeKV) +
std::max(num_threads * sizeof(size_t) * 2,
2 * bdy * bdz * sizeof(float));
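// Worked instance (assumed shapes, not fixed here): HEAD_DIM_CKV = 512,
// HEAD_DIM_KPE = 64, fp16 KV gives vec_size_ckv = 16, bdx = 32, bdy = 8,
// bdz = 1; with NUM_STAGES_SMEM = 2 that is 2 * 8 * 1 * (512 + 64) * 2 =
// 18432 bytes of tile storage plus max(256 * 8 * 2, 2 * 8 * 4) = 4096 bytes
// of offset/md scratch, i.e. 22528 bytes of dynamic shared memory per block.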

auto kernel =
BatchDecodeWithPagedKVCacheKernelMLA<NUM_STAGES_SMEM, vec_size_ckv, vec_size_kpe,
bdx, bdy, bdz, tile_size_qo_heads, AttentionVariant>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

if (tmp_v == nullptr) {
// do not use partition-kv kernel
dim3 nblks(padded_batch_size, gdy);
dim3 nthrs(bdx, bdy, bdz);
params.partition_kv = false;
void* args[] = {(void*)&params};
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
} else {
// use partition-kv kernel
params.partition_kv = true;
auto o = params.o;
auto lse = params.lse;
params.o = tmp_v;
params.lse = tmp_s;
void* args[] = {(void*)&params};
dim3 nblks(padded_batch_size, gdy);
dim3 nthrs(bdx, bdy, bdz);
FLASHINFER_CUDA_CALL(
cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse,
params.paged_kv.batch_size, num_qo_heads,
HEAD_DIM_CKV, stream));
}
});
return cudaSuccess;
}
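// Usage sketch (hypothetical): a thin wrapper showing how the dispatcher
// above could be instantiated for MLA shapes. The concrete AttentionVariant
// and its ParamsT are assumptions defined elsewhere, not part of this hunk.
template <typename Variant>
cudaError_t RunMLADecodeSketch(typename Variant::ParamsT params,
typename Variant::DTypeO* tmp_v, float* tmp_s,
cudaStream_t stream) {
// 512: compressed-KV head dim; 64: decoupled RoPE head dim (assumed values).
return BatchDecodeWithPagedKVCacheDispatchedMLA<512, 64, Variant>(
params, tmp_v, tmp_s, stream);
}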

} // namespace flashinfer

#endif // FLASHINFER_DECODE_CUH_