doc: update documentation for mask layout (#270)
Follow-up of #266: this PR adds docstrings and diagrams for the 2D
ragged tensor mask layout.
yzh119 authored May 28, 2024
1 parent b16bbe4 commit c6b7c20
Showing 2 changed files with 48 additions and 5 deletions.
37 changes: 34 additions & 3 deletions docs/tutorials/kv_layout.rst
@@ -25,8 +25,8 @@ Ragged Tensor
-------------

In batched inference/serving, the input sequence length may vary across different samples.
-When there is no need to change the sequence length (e.g. in prefilling stage), we can use ``RaggedTensor`` to store
-the key/value tensors in KV-Cache:
+When there is no need to change the sequence length (e.g., in the prefill stage), we can use a ``RaggedTensor``
+with a single ragged (variable-length) dimension to store the key/value tensors in the KV-Cache:

.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/ragged.png
:width: 400
@@ -41,6 +41,37 @@ shape ``(indptr[-1], num_heads, head_dim)`` when the layout is ``NHD``.

We can use ``data[indptr[i]:indptr[i+1]]`` to slice the keys (or values) of request ``i``.
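
A minimal sketch of this slicing, with hypothetical sizes (three requests with kv
lengths 2, 3 and 1, ``num_heads=4``, ``head_dim=8``, ``NHD`` layout):

.. code-block:: python

    import torch

    kv_len = torch.tensor([2, 3, 1])
    # indptr = [0, 2, 5, 6]: exclusive prefix sum of the per-request lengths
    indptr = torch.cat([torch.zeros(1, dtype=torch.long), torch.cumsum(kv_len, dim=0)])
    data = torch.randn(int(indptr[-1]), 4, 8)  # shape (indptr[-1], num_heads, head_dim)

    i = 1
    keys_i = data[indptr[i]:indptr[i + 1]]  # keys (or values) of request 1, shape (3, 4, 8)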

.. _mask-layout:

Mask Layout (2D Ragged Tensor)
------------------------------

The aforementioned Ragged Tensor can be generalized to multiple "ragged" dimensions. For example,
the attention mask in FlashInfer is a 2D ragged tensor when the batch size is greater than 1:

.. image:: https://raw.githubusercontent.com/flashinfer-ai/web-data/main/tutorials/mask-layout.png
:width: 800
:align: center
:alt: Data structure of Mask Layout.

When the number of requests is greater than 1, different requests might have different query and kv lengths.
To avoid padding, we use a 2D ragged tensor to store the attention mask. The input ``qo_indptr`` and
``kv_indptr`` arrays (both of length ``num_requests+1``) store the variable sequence lengths of each request:
``qo_indptr[i+1]-qo_indptr[i]`` is the query length of request ``i`` (``qo_len[i]``), and
``kv_indptr[i+1]-kv_indptr[i]`` is the kv length of request ``i`` (``kv_len[i]``).

The mask arrays of all requests are flattened (with query as the first dimension and kv as the last dimension)
and concatenated into a single 1D array: ``mask_data``. FlashInfer creates a ``qk_indptr`` array implicitly
to store the start offset of each request's mask in the flattened mask array: ``qk_indptr[1:] = cumsum(qo_len * kv_len)``.

``mask_data`` has shape ``(qk_indptr[-1],)``; we can use ``mask_data[qk_indptr[i]:qk_indptr[i+1]]`` to slice the flattened
mask of request ``i``.
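
A sketch of both steps, mirroring the formula above with hypothetical lengths
(``qo_len = [2, 3]``, ``kv_len = [3, 4]``):

.. code-block:: python

    import torch

    qo_indptr = torch.tensor([0, 2, 5])
    kv_indptr = torch.tensor([0, 3, 7])
    qo_len = qo_indptr[1:] - qo_indptr[:-1]  # [2, 3]
    kv_len = kv_indptr[1:] - kv_indptr[:-1]  # [3, 4]

    # qk_indptr[1:] = cumsum(qo_len * kv_len)  ->  qk_indptr = [0, 6, 18]
    qk_indptr = torch.cat(
        [torch.zeros(1, dtype=torch.long), torch.cumsum(qo_len * kv_len, dim=0)]
    )
    # Stand-in flattened additive mask: 0 keeps a position, -inf masks it out.
    mask_data = torch.zeros(int(qk_indptr[-1]))

    i = 1
    mask_i = mask_data[qk_indptr[i]:qk_indptr[i + 1]]
    mask_i_2d = mask_i.reshape(int(qo_len[i]), int(kv_len[i]))  # shape (3, 4)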

:class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper`
allow the user to specify ``qo_indptr``, ``kv_indptr`` and a custom attention mask ``custom_mask`` in their ``begin_forward`` functions;
the mask data will be added to the attention score before softmax (and after softmax scaling) in the attention kernel.
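
A usage sketch with the ragged wrapper, assuming the ``begin_forward`` argument order
documented in ``python/flashinfer/prefill.py`` below (workspace size, head counts, and
dtypes are hypothetical):

.. code-block:: python

    import torch
    import flashinfer

    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")

    num_qo_heads, num_kv_heads, head_dim = 32, 32, 128
    qo_indptr = torch.tensor([0, 2, 5], dtype=torch.int32, device="cuda")
    kv_indptr = torch.tensor([0, 3, 7], dtype=torch.int32, device="cuda")
    # Flattened additive mask of length sum(qo_len[i] * kv_len[i]) = 2*3 + 3*4 = 18.
    custom_mask = torch.zeros(18, device="cuda")

    wrapper.begin_forward(
        qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim,
        custom_mask=custom_mask,
    )
    q = torch.randn(5, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(7, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(7, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
    out = wrapper.forward(q, k, v)  # shape (5, num_qo_heads, head_dim)
    wrapper.end_forward()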

.. _page-layout:

FlashInfer APIs
@@ -92,7 +123,7 @@ FlashInfer APIs
:meth:`flashinfer.page.append_paged_kv_cache` can append a batch of keys/values (stored as ragged tensors) to the paged KV-Cache
(the pages for these appended keys/values must be allocated prior to calling this API).

-:class:`BatchDecodeWithPagedKVCacheWrapper` and :class:`BatchPrefillWithPagedKVCacheWrapper` implements the decode attention
+:class:`flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper` and :class:`flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper` implement the decode attention
and prefill/append attention between queries stored in ragged tensors and keys/values stored in paged KV-Cache.

FAQ
16 changes: 14 additions & 2 deletions python/flashinfer/prefill.py
@@ -83,6 +83,8 @@ def single_prefill_with_kv_cache(
``HND``.
custom_mask : Optional[torch.Tensor]
The custom mask tensor, shape: ``[qo_len, kv_len]``.
If provided, the custom mask will be added to the attention matrix before
softmax and after scaling, and the :attr:`causal` parameter will be ignored.
causal : bool
Whether to apply causal mask to the attention matrix.
This is only effective when :attr:`custom_mask` is not provided.
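
A sketch of passing a hand-built additive causal mask through ``custom_mask``
(shapes and dtypes are hypothetical; with a custom mask supplied, ``causal`` is ignored):

.. code-block:: python

    import torch
    import flashinfer

    qo_len, kv_len, num_heads, head_dim = 128, 256, 32, 128
    q = torch.randn(qo_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device="cuda")

    # Additive causal mask: 0 keeps a position, -inf removes it.
    row = torch.arange(qo_len, device="cuda").unsqueeze(1)
    col = torch.arange(kv_len, device="cuda").unsqueeze(0)
    custom_mask = torch.zeros(qo_len, kv_len, device="cuda")
    custom_mask[col > row + (kv_len - qo_len)] = float("-inf")

    out = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=custom_mask)
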
@@ -201,6 +203,8 @@ def single_prefill_with_kv_cache_return_lse(
``HND``.
custom_mask : Optional[torch.Tensor]
The custom_mask tensor, shape: ``[qo_len, kv_len]``.
If provided, the custom mask will be added to the attention matrix before
softmax and after scaling, and the :attr:`causal` parameter will be ignored.
causal : bool
Whether to apply causal mask to the attention matrix.
This is only effective when :attr:`custom_mask` is not provided.
@@ -474,7 +478,11 @@ def begin_forward(
The size of each page in the paged kv-cache.
custom_mask : Optional[torch.Tensor]
The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size)),)``.
-The mask tensor will be applied to the attention matrix before softmax if provided.
+If provided, the custom mask will be added to the attention matrix before softmax
+and after scaling. The mask tensor should be on the same device as the input tensors.
+Please refer to the :ref:`mask layout <mask-layout>` for more details about the flattened
+layout of the mask tensor.
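
As a sketch, this flattened tensor can be assembled from per-request 2D masks
(hypothetical lengths ``q_len = [2, 3]``, ``k_len = [3, 4]``):

.. code-block:: python

    import torch

    q_len, k_len = [2, 3], [3, 4]
    # One 2D additive mask per request, each of shape (q_len[i], k_len[i]).
    masks = [torch.zeros(q, k) for q, k in zip(q_len, k_len)]
    custom_mask = torch.cat([m.flatten() for m in masks])  # shape (18,)
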
Notes
-----
@@ -845,7 +853,11 @@ def begin_forward(
The dimension of the heads.
custom_mask : Optional[torch.Tensor]
The flattened mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size)),)``.
-The mask tensor will be added to the attention matrix before softmax.
+If provided, the custom mask will be added to the attention matrix before softmax
+and after scaling. The mask tensor should be on the same device as the input tensors.
+Please refer to the :ref:`mask layout <mask-layout>` for more details about the flattened
+layout of the mask tensor.
Notes
-----
