[Q&A] Any plans for different dtypes for Q (query) and KV (kv-cache)? #285

Closed
ibsidorenko opened this issue Jun 5, 2024 · 4 comments · Fixed by #286

Contributor

ibsidorenko commented Jun 5, 2024

Hi, All! This is just a question of whether there are such plans or not...

Right now, the FlashInfer library requires Q (query) and KV (kv-cache) to have the same dtype.
For example, in this kernel from the code, q and paged_kv share the same DTypeIn:

template <bool partition_kv, PosEncodingMode pos_encoding_mode, uint32_t num_stages_smem,
          uint32_t tile_size_per_bdx, uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz,
          PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
          typename IdType>
__global__ void BatchDecodeWithPagedKVCacheKernel(
    DTypeIn* __restrict__ q, IdType* __restrict__ q_offset,
    paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
    kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
    DTypeOut* __restrict__ tmp, float* __restrict__ lse, float sm_scale, float rope_rcp_scale,
    float rope_rcp_theta)

Are there any plans to support different dtypes for KV-cache and Q (query)?
My personal interest is fp8 for kv-cache and fp16 for query.

Thank you in advance!
cc @yzh119

Collaborator

yzh119 commented Jun 5, 2024

Thanks for your suggestions. Sure, I think we can definitely support it, and using fp8 for kv-cache and fp16 for q sounds reasonable to me.

I'll split DTypeIn into DTypeQ and DTypeKV in the kernel implementations; the Python APIs don't have to change.
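
A minimal host-side sketch of what such a split could look like, not the actual FlashInfer change. `ToyKV8` and `QKDot` are hypothetical names invented for illustration; `ToyKV8` is a toy 8-bit storage type standing in for a real fp8 format, and the only point is that the query and the kv-cache get independent template parameters while accumulation stays in float:

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical 8-bit storage type standing in for a real fp8 format.
// It stores value * 16 in a signed byte; this is NOT a real fp8 encoding.
struct ToyKV8 {
  std::int8_t bits;
  explicit operator float() const { return static_cast<float>(bits) / 16.0f; }
};

// Previously a single DTypeIn would cover both pointers. Here the query and
// the kv-cache each get their own type, and both are widened to float
// before the multiply-accumulate.
template <typename DTypeQ, typename DTypeKV>
float QKDot(const DTypeQ* q, const DTypeKV* k, std::size_t head_dim) {
  float acc = 0.0f;
  for (std::size_t i = 0; i < head_dim; ++i) {
    acc += static_cast<float>(q[i]) * static_cast<float>(k[i]);
  }
  return acc;
}
```

Calling `QKDot` with the same type for both arguments still compiles, so in this sketch existing same-dtype callers would keep working without changes.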

yzh119 self-assigned this Jun 5, 2024
Contributor

Yard1 commented Jun 5, 2024

Seconding this - I was actually thinking of submitting a PR myself. @yzh119 let me know if you need any help on this (from what I can tell, it should be quite straightforward).

Semi-related, can we expect fp8 support for prefill any time soon? How complicated would it be to add that?

Collaborator

yzh119 commented Jun 5, 2024

> let me know if you need any help on this (from what I can tell, it should be quite straightforward).

Sounds good, I would really appreciate your help!

> can we expect fp8 support for prefill any time soon?

Yes, we are at the last step: dealing with transposed ldmatrix for fp8 (for the V matrix). It should be available soon :)
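
For readers wondering why the transposed load is the tricky part, here is a rough host-side illustration, not FlashInfer code. It assumes the transposed ldmatrix form works on 8x8 tiles of 16-bit elements, so two adjacent 8-bit values travel together as one unit. `Tile8` and `TransposeAs16Bit` are hypothetical names made up for this sketch:

```cpp
#include <array>
#include <cstdint>

constexpr int kRows = 8;    // rows in the tile
constexpr int kCols16 = 8;  // 16-bit units per row, i.e. 16 bytes per row

// An 8 x 16 tile of 8-bit elements, viewed as 8 x 8 units of 16 bits each.
using Tile8 = std::array<std::array<std::uint8_t, 2 * kCols16>, kRows>;

// Transpose the tile at 16-bit granularity: each pair of adjacent bytes
// moves as one unit, which mimics transposing a tile of 16-bit elements.
inline Tile8 TransposeAs16Bit(const Tile8& in) {
  Tile8 out{};
  for (int r = 0; r < kRows; ++r) {
    for (int c = 0; c < kCols16; ++c) {
      out[c][2 * r] = in[r][2 * c];
      out[c][2 * r + 1] = in[r][2 * c + 1];
    }
  }
  return out;
}
```

In this model, each output row interleaves bytes from two neighbouring source columns instead of holding a single column, so an element-wise transpose of 8-bit data would need an extra byte shuffle on top of the 16-bit transpose.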

Contributor

Yard1 commented Jun 5, 2024

Ok, let me see if I can get a PR going this week!
