update according to recent upstream changes and some cleanup
tsu-bin committed Oct 30, 2024
1 parent 478c835 commit d8f10ee
Showing 4 changed files with 29 additions and 31 deletions.
2 changes: 1 addition & 1 deletion python/flashinfer/__init__.py
@@ -24,14 +24,14 @@
     BatchDecodeWithSharedPrefixPagedKVCacheWrapper as BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
     BatchPrefillWithSharedPrefixPagedKVCacheWrapper as BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
     MultiLevelCascadeAttentionWrapper as MultiLevelCascadeAttentionWrapper,
-    BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
     merge_state as merge_state,
     merge_state_in_place as merge_state_in_place,
     merge_states as merge_states,
 )
 from .decode import (
     BatchDecodeWithPagedKVCacheWrapper as BatchDecodeWithPagedKVCacheWrapper,
     CUDAGraphBatchDecodeWithPagedKVCacheWrapper as CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
+    BatchDecodeMlaWithPagedKVCacheWrapper as BatchDecodeMlaWithPagedKVCacheWrapper,
     single_decode_with_kv_cache as single_decode_with_kv_cache,
 )
 from .gemm import (
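
The net effect for users: BatchDecodeMlaWithPagedKVCacheWrapper is now re-exported from flashinfer.decode rather than flashinfer.cascade, and the top-level import is unchanged. A minimal check, assuming a flashinfer build that includes the MLA wrapper:

    # Still resolves after this commit; only the internal re-export moved
    # from flashinfer.cascade to flashinfer.decode.
    from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
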
5 changes: 2 additions & 3 deletions python/flashinfer/decode.py
@@ -77,11 +77,10 @@ def compile_batch_decode_mla_module(
     *args,
     verbose: bool = False,
 ):
-    gen_batch_decode_mla_cu(*args)
-    uri = get_batch_decode_mla_uri(*args)
+    uri, path = gen_batch_decode_mla_cu(*args)
     return load_cuda_ops(
         uri,
-        [FLASHINFER_GEN_SRC_DIR / f"{uri}.cu"],
+        [path],
         verbose=verbose,
     )
 
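
This refactor makes the source generator the single owner of the generated path: compile_batch_decode_mla_module now consumes the (uri, path) pair returned by gen_batch_decode_mla_cu instead of recomputing the file name from FLASHINFER_GEN_SRC_DIR. A self-contained sketch of the pattern, with stand-in names rather than flashinfer's actual signatures:

    from pathlib import Path
    from tempfile import mkdtemp

    GEN_SRC_DIR = Path(mkdtemp())  # stand-in for FLASHINFER_GEN_SRC_DIR

    def gen_kernel_cu(uri: str, cu_source: str) -> tuple[str, Path]:
        # The generator decides where the file lives and reports it back,
        # so callers never rebuild "{uri}.cu" themselves.
        path = GEN_SRC_DIR / f"{uri}.cu"
        path.write_text(cu_source)
        return uri, path

    uri, path = gen_kernel_cu("batch_decode_mla_demo", "// generated source\n")
    sources = [path]  # the list a compile helper would receive after this commit
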
7 changes: 5 additions & 2 deletions python/flashinfer/jit/attention.py
@@ -194,11 +194,14 @@ def gen_batch_decode_mla_cu(*args) -> None:
     gen_directory = FLASHINFER_GEN_SRC_DIR
     if not os.path.exists(gen_directory):
         os.makedirs(gen_directory)
-    file_name = f"{get_batch_decode_mla_uri(*args)}.cu"
+    uri = get_batch_decode_mla_uri(*args)
+    file_name = f"{uri}.cu"
+    path = gen_directory / file_name
     write_if_different(
-        gen_directory / file_name,
+        path,
         get_batch_decode_mla_cu_str(*args),
     )
+    return uri, path
 
 def get_single_prefill_cu_str(
     dtype_q: torch.dtype,
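
On the generator side, gen_batch_decode_mla_cu now computes the URI once, derives the path from it, and returns both (the -> None annotation visible in the hunk header is stale after this change). The write_if_different idiom it calls keeps unchanged generated sources untouched so the JIT build cache stays warm; a stand-in sketch of that idiom, not flashinfer's exact implementation:

    from pathlib import Path

    def write_if_different(path: Path, content: str) -> None:
        # Leave the file alone when the generated source is unchanged,
        # so its mtime does not invalidate the JIT compile cache.
        if path.exists() and path.read_text() == content:
            return
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(content)
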
46 changes: 21 additions & 25 deletions src/flashinfer_ops.cuh
@@ -607,25 +607,23 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapperMLA(
   DISPATCH_head_dim(paged_kv.head_dim_ckv, HEAD_DIM_CKV, {
     // fixme: head_dim_ckv (kv_lora_rank) is 8 times the size of head_dim_kpe (qk_rope_head_dim) for all MLA models (DeepSeek-V2-Lite, DeepSeek-V2.5, MiniCPM3) as of Oct 2024
     constexpr auto HEAD_DIM_KPE = HEAD_DIM_CKV/8;
-    // DISPATCH_head_dim(paged_kv.head_dim_kpe, HEAD_DIM_KPE, {
-      using ParamsT = BatchDecodeParamsMLA<DTypeQ, DTypeKV, DTypeO, IdType>;
-      using AttentionVariant =
-          ComposedAttention<ParamsT, get_variant_code(
-                                         /*use_custom_mask=*/false, /*use_sliding_window=*/true,
-                                         /*use_logits_soft_cap=*/false, /*use_alibi=*/false)>;
-      ParamsT params(q_nope, q_pe, q_offset, paged_kv, o, lse, num_qo_heads,
-                     /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale,
-                     rope_theta);
-      params.request_indices = handler->GetRequestIndices<IdType>();
-      params.kv_tile_indices = handler->GetKVTileIndices<IdType>();
-      params.o_indptr = handler->GetOIndptr<IdType>();
-      params.kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
-      params.block_valid_mask = handler->GetBlockValidMask();
-      params.padded_batch_size = handler->GetPlanInfo().padded_batch_size;
-
-      return BatchDecodeWithPagedKVCacheDispatchedMLA<HEAD_DIM_CKV, HEAD_DIM_KPE, AttentionVariant>(
-          params, handler->GetTmpV<DTypeO>(), handler->GetTmpS(), stream);
-    // });
+    using ParamsT = BatchDecodeParamsMLA<DTypeQ, DTypeKV, DTypeO, IdType>;
+    using AttentionVariant =
+        ComposedAttention<ParamsT, get_variant_code(
+                                       /*use_custom_mask=*/false, /*use_sliding_window=*/true,
+                                       /*use_logits_soft_cap=*/false, /*use_alibi=*/false)>;
+    ParamsT params(q_nope, q_pe, q_offset, paged_kv, o, lse, num_qo_heads,
+                   /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale,
+                   rope_theta);
+    params.request_indices = handler->GetRequestIndices<IdType>();
+    params.kv_tile_indices = handler->GetKVTileIndices<IdType>();
+    params.o_indptr = handler->GetOIndptr<IdType>();
+    params.kv_chunk_size_ptr = handler->GetKVChunkSizePtr<IdType>();
+    params.block_valid_mask = handler->GetBlockValidMask();
+    params.padded_batch_size = handler->GetPlanInfo().padded_batch_size;
+
+    return BatchDecodeWithPagedKVCacheDispatchedMLA<HEAD_DIM_CKV, HEAD_DIM_KPE, AttentionVariant>(
+        params, handler->GetTmpV<DTypeO>(), handler->GetTmpS(), stream);
   });
   return cudaSuccess;
 }
@@ -640,12 +638,10 @@ cudaError_t BatchDecodeHandlerPlanMLA(BatchDecodeHandler* handler, void* float_b
   DISPATCH_head_dim(head_dim_ckv, HEAD_DIM_CKV, {
     // fixme: head_dim_ckv (kv_lora_rank) is 8 times the size of head_dim_kpe (qk_rope_head_dim) for all MLA models (DeepSeek-V2-Lite, DeepSeek-V2.5, MiniCPM3) as of Oct 2024
     constexpr auto HEAD_DIM_KPE = HEAD_DIM_CKV/8;
-    // DISPATCH_head_dim(head_dim_kpe, HEAD_DIM_KPE, {
-      return handler->PlanDispatchedMLA<HEAD_DIM_CKV, HEAD_DIM_KPE, DTypeQ, DTypeKV, DTypeO, IdType>(
-          float_buffer, float_workspace_size_in_bytes, int_buffer, int_workspace_size_in_bytes,
-          indptr_h, last_page_len_h, batch_size, num_qo_heads, page_size);
-    // });
+    return handler->PlanDispatchedMLA<HEAD_DIM_CKV, HEAD_DIM_KPE, DTypeQ, DTypeKV, DTypeO, IdType>(
+        float_buffer, float_workspace_size_in_bytes, int_buffer, int_workspace_size_in_bytes,
+        indptr_h, last_page_len_h, batch_size, num_qo_heads, page_size);
   });
 }
 
 } // namespace flashinfer
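
The cuh cleanup drops the commented-out nested DISPATCH_head_dim over HEAD_DIM_KPE and re-indents the body, leaving the compile-time relation HEAD_DIM_KPE = HEAD_DIM_CKV/8 as the only dispatch. A quick numeric check of that ratio in Python, assuming the published DeepSeek-V2 MLA dimensions (values are illustrative, not read from flashinfer):

    # Assumed DeepSeek-V2 MLA dimensions, for illustration only.
    kv_lora_rank = 512      # head_dim_ckv
    qk_rope_head_dim = 64   # head_dim_kpe
    assert qk_rope_head_dim == kv_lora_rank // 8  # the 8:1 ratio the dispatch hard-codes
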
