[Fix][Relax] Add the missing tree-attn func arg for KV cache creation (#17345)

This PR fixes the TIRPagedKVCache construction issue, which is caused
by the missing tree-attention kernel for the paged KV cache.
MasterJH5574 authored Sep 8, 2024
1 parent 995524a commit e468426
Showing 1 changed file with 1 addition and 0 deletions: python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -375,6 +375,7 @@ def __init__( # pylint: disable=too-many-locals
bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, head_dim, dtype), "kv_cache_debug_get_kv"),
bb.add_func(_compact_kv_copy(num_key_value_heads, head_dim, dtype, target), "kv_cache_compact_kv_copy"),
bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
+bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
rope_ext_factors,
# fmt: on
# pylint: enable=line-too-long
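For context, here is a minimal sketch (not part of this commit) of the registration pattern the diff relies on: relax.BlockBuilder.add_func inserts a TIR PrimFunc into the IRModule under construction and returns a GlobalVar, with the second argument becoming the global symbol by which the kernel is later looked up. The fix registers tree_attn_with_paged_kv_cache under "tir_attention_prefill_with_tree_mask_with_paged_kv_cache" in exactly this way; copy_kernel below is a hypothetical stand-in for the real attention kernel.

import tvm
from tvm import relax
from tvm.script import tir as T

# Hypothetical stand-in kernel; the real code registers attention kernels
# such as tree_attn_with_paged_kv_cache built for the given target.
@T.prim_func
def copy_kernel(a: T.handle, b: T.handle):
    A = T.match_buffer(a, (8,), "float32")
    B = T.match_buffer(b, (8,), "float32")
    for i in range(8):
        with T.block("copy"):
            vi = T.axis.remap("S", [i])
            B[vi] = A[vi]

bb = relax.BlockBuilder()
# add_func inserts the PrimFunc into the IRModule and attaches the global
# symbol by which callers can later retrieve it.
gv = bb.add_func(copy_kernel, "tir_copy_kernel")
print(bb.get())  # the IRModule now contains the kernel as @tir_copy_kernel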

