QUESTION: How to implement tree attention with flashinfer #152
Comments
In order to support the feature of …
We also need to support Medusa when we use MLC-LLM again, and we have seen that TensorRT-LLM supports Medusa.
The current implementation of Medusa in TensorRT-LLM is not fully functional, nor is it a SOTA implementation. By the way, if Medusa is not implemented with a tree mask, you can simply add a verification module at the model output without modifying the kernel code in the project; however, performance will be slightly worse and there will be redundant validation. A sketch of that approach is shown below.
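For concreteness, here is a minimal sketch of the "verification at the model output" approach mentioned above, assuming a Hugging-Face-style causal LM whose forward returns `.logits`; all names are illustrative, not any project's API. Each candidate path is re-scored with ordinary causal attention, so shared prefixes are recomputed, which is the redundant validation referred to.

```python
import torch

@torch.no_grad()
def verify_candidates_without_tree_mask(model, prefix_ids, candidates):
    """Greedy verification of Medusa-style draft tokens at the model output,
    with no custom kernel or tree mask: every candidate path is re-scored as
    an ordinary causal sequence, so shared prefixes are recomputed.

    Assumes a Hugging-Face-style causal LM returning `.logits` of shape
    [batch, seq_len, vocab]; all names here are illustrative.
    """
    best_ids, best_len = prefix_ids, 0
    for cand in candidates:                    # cand: 1-D tensor of draft token ids
        seq = torch.cat([prefix_ids, cand])
        logits = model(seq.unsqueeze(0)).logits[0]
        # The greedy prediction at position i-1 must reproduce the draft token at i.
        preds = logits[len(prefix_ids) - 1 : len(seq) - 1].argmax(dim=-1)
        accepted = int(torch.cumprod((preds == cand).long(), dim=0).sum())
        if accepted > best_len:
            best_ids, best_len = torch.cat([prefix_ids, cand[:accepted]]), accepted
    return best_ids, best_len
```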
Agreed. But by using …
…ls (#266): Some speculative decoding algorithms require tree attention, which can be supported via the prefill/append attention kernels with a custom attention mask. This PR supports this feature. Related issues: #152. API Breaking Changes: the `begin_forward` function in `BatchPrefillWithPagedKVCacheWrapper` now has an additional argument `page_size` to accommodate this new feature.
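For orientation, here is a rough sketch of driving the wrapper after this change, with `page_size` passed explicitly to `begin_forward`. The argument order, tensor layouts, and toy sizes follow my reading of the documentation from around this release and are assumptions; check them against the flashinfer version you actually use.

```python
import torch
import flashinfer

# Toy single-request setup; all sizes are arbitrary, for illustration only.
qo_len, page_size = 8, 16
num_qo_heads, num_kv_heads, head_dim = 8, 8, 128
device = "cuda"

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD")

qo_indptr = torch.tensor([0, qo_len], dtype=torch.int32, device=device)
kv_indptr = torch.tensor([0, 1], dtype=torch.int32, device=device)      # one page
kv_indices = torch.tensor([0], dtype=torch.int32, device=device)
kv_last_page_len = torch.tensor([qo_len], dtype=torch.int32, device=device)

wrapper.begin_forward(
    qo_indptr, kv_indptr, kv_indices, kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim,
    page_size,  # the additional argument introduced by this PR
)

q = torch.randn(qo_len, num_qo_heads, head_dim, dtype=torch.float16, device=device)
kv_data = torch.randn(1, 2, page_size, num_kv_heads, head_dim,
                      dtype=torch.float16, device=device)
o = wrapper.forward(q, kv_data, causal=True)   # see the linked test file for custom-mask usage
wrapper.end_forward()
```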
@zhyncs @UranusSeven We added support for custom attention masks in #266; more documentation is coming.
Cheers!
Hi @yzh119, thanks for your great contribution on this issue! I'd like to adopt flashinfer (with a custom causal mask) in my current project. However, I have a small question: which value should I set in the …
Thanks for your amazing work!
@chenzhuofu @UranusSeven Hi, thanks for your attention. Yes, I think setting -inf or -5e4 for masked positions and 0 for the others is correct. Some simple examples: https://docs.flashinfer.ai/generated/flashinfer.prefill.single_prefill_with_kv_cache.html#flashinfer.prefill.single_prefill_with_kv_cache, or check this test case for batch attention: flashinfer/python/tests/test_batch_prefill_kernels.py, lines 317 to 369 at commit 7aadc0d.
Using a triu attention mask (filled with -inf for masked positions and 0 elsewhere) is equivalent to setting `causal=True`.
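To make that convention concrete, here is a small pure-PyTorch sketch of the additive mask (0 = keep, -inf or -5e4 = masked). How the mask is handed to flashinfer (e.g. via the `custom_mask` argument shown in the linked `single_prefill_with_kv_cache` docs) can differ between releases, so treat the call site as version-dependent.

```python
import torch

def causal_additive_mask(qo_len, kv_len, dtype=torch.float32):
    """Additive attention mask: 0 where attention is allowed, a large negative
    number where it is masked.  float('-inf') is fine for fp32; -5e4 is a common
    fp16-safe substitute, since fp16 cannot represent values below roughly -65504."""
    neg = -5e4 if dtype == torch.float16 else float("-inf")
    q_idx = torch.arange(qo_len)[:, None]
    k_idx = torch.arange(kv_len)[None, :]
    # Query i (aligned to the end of the KV sequence) may only attend to
    # KV positions j with j <= i + (kv_len - qo_len).
    mask = torch.zeros(qo_len, kv_len, dtype=dtype)
    return mask.masked_fill(k_idx > q_idx + (kv_len - qo_len), neg)

# When qo_len == kv_len, this is exactly an upper-triangular -inf matrix,
# which matches running the kernel with causal=True.
m = causal_additive_mask(4, 4)
assert torch.equal(m, torch.full((4, 4), float("-inf")).triu(diagonal=1))
```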
Hi, flashinfer now supports custom masks, which is great work! But what about positional embeddings? I found #69 introducing q_position and kv_position in the C++ kernels, but I didn't find a relevant Python API (am I missing something?).
Another option is to use the SparseAttentionWrapper, which can greatly accelerate attention with sparse masks.
Hi, thanks for your awesome work!
I'm trying to implement https://github.com/SafeAILab/EAGLE with high-performance kernels. I read this blog and it says … However, I was unable to locate arguments like `position_id` (used for rotary embedding) and `attention_mask` (for enforcing causality constraints). Could you please provide an example of implementing tree attention using flashinfer? Any guidance you can offer would be greatly appreciated.
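For context, here is a sketch (with hypothetical names, not a flashinfer API) of what those two inputs would contain for a draft token tree: each draft token's rotary position is the committed prefix length plus its depth in the tree, and the mask allows it to attend to the prefix, its ancestors, and itself.

```python
import torch

def tree_positions_and_mask(parent, prefix_len):
    """`parent[i]` is the index of draft token i's parent within the tree
    (-1 if its parent is the last committed token).  Tokens are assumed to be
    listed so that parents precede children.  Returns:
      * position_ids: rotary position of each draft token (prefix_len + depth)
      * allow: [n, prefix_len + n] boolean matrix, True where attention is allowed
    """
    n = len(parent)
    depth = [0] * n
    allow = torch.zeros(n, prefix_len + n, dtype=torch.bool)
    allow[:, :prefix_len] = True                 # every draft token sees the prefix
    for i in range(n):
        allow[i, prefix_len + i] = True          # ...and itself
        j = parent[i]
        while j != -1:                           # ...and all of its tree ancestors
            allow[i, prefix_len + j] = True
            j = parent[j]
        depth[i] = 0 if parent[i] == -1 else depth[parent[i]] + 1
    position_ids = torch.tensor([prefix_len + d for d in depth])
    return position_ids, allow

# Example: two root candidates; token 2 extends token 0, token 3 extends token 1.
pos, allow = tree_positions_and_mask([-1, -1, 0, 1], prefix_len=4)
# pos == tensor([4, 4, 5, 5]); an additive mask is 0 where allow is True, -inf elsewhere.
```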