
QUESTION: How to implement a tree attention with flashinfer #152

Closed
UranusSeven opened this issue Mar 4, 2024 · 11 comments

Hi, thanks for your awesome work!

I'm trying to implement https://github.com/SafeAILab/EAGLE with high-performance kernels. I read this blog and it says

> FlashInfer implements prefill/append kernels for Paged KV-Cache which none of the existing libraries have done before, and it can be used to serve models in speculative decoding setting.

However, I was unable to locate arguments like position_id (utilized for rotary embedding) and attention_mask (for enforcing causality constraints).

Could you please provide an example of implementing tree attention with flashinfer? Any guidance you can offer would be greatly appreciated.

yzh119 self-assigned this Mar 4, 2024

zhyncs (Member) commented Mar 5, 2024

To support the out-of-order token positions introduced by speculative decoding, two adjustments are needed: one to the cos/sin matrix used by RoPE, and the other to replace the causal mask with a tree mask. With these, it becomes very convenient to implement algorithms such as Medusa and EAGLE. Judging from the documentation at https://docs.flashinfer.ai/index.html, this is not supported yet.
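
To make the RoPE side concrete: with a draft tree, the position of each draft token is the prefix length plus its depth in the tree, so sibling tokens share the same position and the cos/sin values have to be gathered per token instead of being read from a contiguous range. A rough sketch (the parent-index layout of the tree is only an illustration, not anything flashinfer prescribes):

```python
def tree_position_ids(prefix_len: int, parents: list[int]) -> list[int]:
    """RoPE position for every draft token in a flattened tree.

    parents[i] is the parent index of draft token i, or -1 if it attaches
    directly to the last committed token. Parents are assumed to appear
    before their children in the flattened order.
    """
    depth = [0] * len(parents)
    for i, p in enumerate(parents):
        depth[i] = 0 if p == -1 else depth[p] + 1
    return [prefix_len + d for d in depth]


# Draft tree 0 -> {1, 2}, 1 -> {3} after an 8-token prefix:
print(tree_position_ids(8, [-1, 0, 0, 1]))   # [8, 9, 9, 10]
```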

jpf888 commented Mar 5, 2024

We also need Medusa support when we use MLC-LLM, and we have seen that TensorRT-LLM supports Medusa.

zhyncs (Member) commented Mar 5, 2024

> We also need Medusa support when we use MLC-LLM, and we have seen that TensorRT-LLM supports Medusa.

The current implementation of Medusa in TensorRT-LLM is not fully functional, nor is it a SOTA implementation. By the way, if Medusa is not implemented based on a tree mask, you can simply add a verification module at the model output without modifying any kernel code in the project. However, performance will be slightly worse and there will be redundant verification.
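
For reference, such a verification module can be a few lines of PyTorch; the sketch below (with placeholder names `draft_tokens` and `target_logits`, not any library API) accepts the longest prefix of a draft chain that matches the target model's greedy predictions:

```python
import torch


def greedy_verify(draft_tokens: torch.Tensor, target_logits: torch.Tensor) -> int:
    """Number of accepted draft tokens.

    draft_tokens:  (num_draft,) token ids proposed by the draft model.
    target_logits: (num_draft, vocab_size) target-model logits at the positions
                   that predict each of those draft tokens.
    """
    target_pred = target_logits.argmax(dim=-1)             # target model's greedy choice
    matches = (target_pred == draft_tokens).to(torch.int64)
    # cumprod zeroes everything after the first mismatch, so the sum is the
    # length of the accepted prefix
    return int(matches.cumprod(dim=0).sum().item())
```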

UranusSeven (Author) commented:

> To support the out-of-order token positions introduced by speculative decoding, two adjustments are needed: one to the cos/sin matrix used by RoPE, and the other to replace the causal mask with a tree mask. With these, it becomes very convenient to implement algorithms such as Medusa and EAGLE. Judging from the documentation at https://docs.flashinfer.ai/index.html, this is not supported yet.

Agreed. But by using BatchPrefillWithPagedKVCacheWrapper, we can sidestep the attention mask issue by turning the draft tree into a batch of ordinary sequences, as sketched below.
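
For example, a rough sketch of that expansion (illustrative only, not a flashinfer API): enumerate the root-to-leaf paths of the draft tree and run each one as its own request, so the ordinary causal mask is correct at the cost of recomputing the tokens shared between paths.

```python
def tree_to_batch(parents: list[int]) -> list[list[int]]:
    """Expand a draft tree into one index path per leaf.

    parents[i] is the parent index of draft node i, or -1 if node i attaches
    directly to the committed prefix.
    """
    children = {i: [] for i in range(len(parents))}
    roots = []
    for i, p in enumerate(parents):
        (roots if p == -1 else children[p]).append(i)
    paths, stack = [], [(r, [r]) for r in roots]
    while stack:
        node, path = stack.pop()
        if not children[node]:            # leaf: emit the whole root-to-leaf path
            paths.append(path)
        for child in children[node]:
            stack.append((child, path + [child]))
    return paths


# Draft tree 0 -> {1, 2}, 1 -> {3} expands into the paths [0, 2] and [0, 1, 3].
print(tree_to_batch([-1, 0, 0, 1]))
```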

yzh119 added a commit that referenced this issue May 28, 2024
…ls (#266)

Some speculative decoding algorithms require tree attention, which can be
supported via prefill/append attention kernels with a custom attention mask.

This PR supports this feature.

Related issues: #152 

# API Breaking Changes

The `begin_forward` function in `BatchPrefillWithPagedKVCacheWrapper`
now has an additional argument `page_size` to accommodate this new
feature.

yzh119 (Collaborator) commented May 28, 2024

@zhyncs @UranusSeven We now support custom attention masks as of #266; more documentation is coming.

zhyncs (Member) commented May 28, 2024

> @zhyncs @UranusSeven We now support custom attention masks as of #266; more documentation is coming.

Cheers!

chenzhuofu commented:
Hi @yzh119, thanks for your great contribution on this issue! I would like to adopt flashinfer (with a custom causal mask) in my current project. However, I have a small question: which values should I set in custom_mask? I guess I should set -5e4 for masked positions and 0 elsewhere. Am I right? :)

UranusSeven (Author) commented:

> @zhyncs @UranusSeven We now support custom attention masks as of #266; more documentation is coming.

Thanks for your amazing work!

yzh119 (Collaborator) commented Jun 4, 2024

@chenzhuofu @UranusSeven Hi, thanks for your attention. Yes, setting -inf or -5e4 for masked positions and 0 for the others is correct.

Some simple examples: https://docs.flashinfer.ai/generated/flashinfer.prefill.single_prefill_with_kv_cache.html#flashinfer.prefill.single_prefill_with_kv_cache

or check this test case for batch attention:

import numpy
import pytest
import torch

import flashinfer


@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"])
def test_batch_prefill_with_ragged_kv_cache_custom_mask(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    pos_encoding_mode,
):
    kv_layout = "NHD"
    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
    q_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
    k = torch.randn(batch_size * kv_len, num_kv_heads, head_dim).to(0).half()
    v = torch.randn(batch_size * kv_len, num_kv_heads, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    # causal mask expressed as a flattened custom mask: -5e4 above the
    # diagonal (masked), 0 elsewhere (visible)
    custom_mask = (
        torch.triu(
            torch.full((batch_size, qo_len, kv_len), -5e4, dtype=torch.float32),
            diagonal=(kv_len - qo_len + 1),
        )
        .reshape(-1)
        .to(0)
    )
    # use custom mask
    wrapper.begin_forward(
        q_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, custom_mask
    )
    o_custom = wrapper.forward(q, k, v, pos_encoding_mode=pos_encoding_mode)
    wrapper.end_forward()
    # use causal
    wrapper.begin_forward(q_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim)
    o_causal = wrapper.forward(
        q, k, v, causal=True, pos_encoding_mode=pos_encoding_mode
    )
    numpy.testing.assert_allclose(
        o_custom.cpu().numpy(), o_causal.cpu().numpy(), rtol=1e-3, atol=1e-3
    )

Using a triu attention mask (filled with -inf for masked positions and 0 for the others) is equivalent to setting causal=True.
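
For the tree-attention case this issue asks about, the same 0 / -5e4 convention can be used to build a per-request mask that lets every draft token attend to the committed prefix plus its ancestors in the draft tree. A rough sketch (the parent-index representation of the tree is only an illustration, not part of the flashinfer API):

```python
import torch


def tree_custom_mask(prefix_len: int, parents: list[int]) -> torch.Tensor:
    """Per-request mask of shape (num_draft, prefix_len + num_draft).

    parents[i] is the parent index of draft token i, or -1 if it attaches
    directly to the committed prefix. 0 marks visible positions, -5e4 marks
    masked ones, matching the convention in the test above.
    """
    n = len(parents)
    mask = torch.full((n, prefix_len + n), -5e4, dtype=torch.float32)
    mask[:, :prefix_len] = 0.0            # the prefix is visible to every draft token
    for i in range(n):
        j = i
        while j != -1:                    # walk from token i up through its ancestors
            mask[i, prefix_len + j] = 0.0
            j = parents[j]
    return mask


# One request: an 8-token committed prefix plus the draft tree 0 -> {1, 2}, 1 -> {3}.
# The flattened result plays the role of `custom_mask` in the test above.
flat_mask = tree_custom_mask(8, [-1, 0, 0, 1]).reshape(-1)
```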

Tomorrowdawn commented:
Hi, flashinfer now supports custom masks, which is great work! But what about the positional embedding? I found #69 introducing q_position and kv_position in the C++ kernels, but I didn't find a corresponding Python API (am I missing something?).

yzh119 (Collaborator) commented Sep 2, 2024

Another option is to use the SparseAttentionWrapper, which can greatly accelerate attention with sparse masks.

yzh119 closed this as completed Sep 2, 2024