doc: update the docstring related to alibi #147

Merged · 1 commit · Mar 3, 2024
4 changes: 2 additions & 2 deletions README.md

@@ -72,13 +72,13 @@ num_qo_heads = 32
 q = torch.randn(num_qo_heads, head_dim).half().to(0)

 o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly
-o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="LLAMA") # decode with LLaMA style RoPE on-the-fly
+o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ROPE_LLAMA") # decode with LLaMA style RoPE on-the-fly

 # append attention
 append_qo_len = 128
 q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens
 o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask
-o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask
+o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="ROPE_LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask

 # prefill attention
 qo_len = 2048
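The README change above only renames the enum value from ``LLAMA`` to ``ROPE_LLAMA``; the docstring changes below additionally list ``ALIBI``. A minimal sketch of a single-request decode call using that mode, in the style of the README example (the kv_len/num_kv_heads/head_dim sizes and the NHD-shaped k/v tensors are illustrative assumptions, not part of this diff):

```python
import torch
import flashinfer

kv_len, num_kv_heads, head_dim = 4096, 32, 128  # assumed sizes, mirroring the README style
num_qo_heads = 32

# NHD layout (the default): k/v are [kv_len, num_kv_heads, head_dim]
k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
q = torch.randn(num_qo_heads, head_dim).half().to(0)

# decode attention with the ALiBi bias applied on the fly inside the kernel
o_alibi = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ALIBI")
```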
3 changes: 0 additions & 3 deletions python/flashinfer/cascade.py

@@ -419,9 +419,6 @@ def begin_forward(
 The dimension of the heads
 page_size : int
 The page size of the paged kv cache
-pos_encoding_mode : str
-Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
 data_type : Union[str, torch.dtype]
 The data type of the paged kv cache

12 changes: 6 additions & 6 deletions python/flashinfer/decode.py

@@ -79,7 +79,7 @@ def single_decode_with_kv_cache(
 The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 sm_scale : Optional[float]
 The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
 rope_scale : Optional[float]
@@ -168,7 +168,7 @@ def batch_decode_with_padded_kv_cache(
 The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 sm_scale : Optional[float]
 The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
 rope_scale : Optional[float]
@@ -257,7 +257,7 @@ def batch_decode_with_padded_kv_cache_return_lse(
 The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 sm_scale : Optional[float]
 The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
 rope_scale : Optional[float]
@@ -456,7 +456,7 @@ def begin_forward(
 The page size of the paged kv cache
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 data_type : Union[str, torch.dtype]
 The data type of the paged kv cache

@@ -525,7 +525,7 @@ def forward(
 :attr:`kv_layout` is ``HND``.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 sm_scale : Optional[float]
 The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
 rope_scale : Optional[float]
@@ -586,7 +586,7 @@ def forward_return_lse(
 :attr:`kv_layout` is ``HND``.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 sm_scale : Optional[float]
 The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``.
 rope_scale : Optional[float]
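The same three-valued ``pos_encoding_mode`` wording now appears in the batch decode docstrings above. A hedged sketch of a padded batch decode call with ALiBi; the argument order and the padded k/v shapes are assumptions on my part, only the parameter name and its accepted values come from this diff:

```python
import torch
import flashinfer

batch_size, padded_kv_len = 8, 1024          # assumed sizes
num_qo_heads = num_kv_heads = 32
head_dim = 128

# assumed NHD layout: q is [batch, num_qo_heads, head_dim],
# padded k/v are [batch, padded_kv_len, num_kv_heads, head_dim]
q = torch.randn(batch_size, num_qo_heads, head_dim).half().to(0)
k_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to(0)
v_padded = torch.randn(batch_size, padded_kv_len, num_kv_heads, head_dim).half().to(0)

# ALiBi bias applied on the fly inside the batch decode kernel
o = flashinfer.batch_decode_with_padded_kv_cache(
    q, k_padded, v_padded, pos_encoding_mode="ALIBI"
)
```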
12 changes: 6 additions & 6 deletions python/flashinfer/prefill.py

@@ -93,7 +93,7 @@ def single_prefill_with_kv_cache(
 The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 allow_fp16_qk_reduction : bool
 Whether to use f16 for qk reduction (faster at the cost of slight precision
 loss).
@@ -191,7 +191,7 @@ def single_prefill_with_kv_cache_return_lse(
 The layout of the input k/v tensors, could be either ``NHD`` or ``HND``.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 allow_fp16_qk_reduction : bool
 Whether to use f16 for qk reduction (faster at the cost of slight precision
 loss).
@@ -460,7 +460,7 @@ def forward(
 Whether to apply causal mask to the attention matrix.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 allow_fp16_qk_reduction : bool
 Whether to use f16 for qk reduction (faster at the cost of slight precision
 loss).
@@ -529,7 +529,7 @@ def forward_return_lse(
 Whether to apply causal mask to the attention matrix.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 allow_fp16_qk_reduction : bool
 Whether to use f16 for qk reduction (faster at the cost of slight precision
 loss).
@@ -744,7 +744,7 @@ def forward(
 Whether to apply causal mask to the attention matrix.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 allow_fp16_qk_reduction : bool
 Whether to use f16 for qk reduction (faster at the cost of slight precision
 loss).
@@ -811,7 +811,7 @@ def forward_return_lse(
 Whether to apply causal mask to the attention matrix.
 pos_encoding_mode : str
 Whether to apply RoPE on-the-fly inside attention kernels, could be
-``NONE`` or ``LLAMA`` (LLAMA style rotary embedding).
+``NONE``/``ROPE_LLAMA``(LLAMA style rotary embedding)/``ALIBI``.
 allow_fp16_qk_reduction : bool
 Whether to use f16 for qk reduction (faster at the cost of slight precision
 loss).
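For the prefill/append path, the README snippet above already shows ``causal=True`` and ``pos_encoding_mode``; combining that with the ``ALIBI`` value and the ``allow_fp16_qk_reduction`` flag documented here gives a sketch like the following (tensor sizes are illustrative assumptions, not part of this diff):

```python
import torch
import flashinfer

qo_len = kv_len = 2048                      # assumed sizes
num_qo_heads = num_kv_heads = 32
head_dim = 128

q = torch.randn(qo_len, num_qo_heads, head_dim).half().to(0)
k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)

# causal prefill attention with ALiBi applied inside the kernel;
# allow_fp16_qk_reduction trades a little precision for speed, per the docstring
o = flashinfer.single_prefill_with_kv_cache(
    q, k, v, causal=True, pos_encoding_mode="ALIBI", allow_fp16_qk_reduction=True
)
```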