doc: update documentation index #603

Merged (1 commit) on Nov 11, 2024
18 changes: 18 additions & 0 deletions docs/api/python/activation.rst
@@ -0,0 +1,18 @@
.. _apiactivation:

flashinfer.activation
=====================

.. currentmodule:: flashinfer.activation

This module provides a set of activation operations for up/gate layers in transformer MLPs.

Up/Gate output activation
-------------------------

.. autosummary::
:toctree: ../../generated

silu_and_mul
gelu_tanh_and_mul
gelu_and_mul
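
A hedged usage sketch for these fused kernels (not part of the diff): it assumes the functions are importable from ``flashinfer.activation`` as the module path above suggests, that a single-argument call allocates the output (the signatures later in this diff default ``out=None``), and that the input packs the two projection halves along the last dimension as ``2 * hidden_size``.

import torch
import flashinfer

num_tokens, hidden_size = 16, 4096
# up/gate projection output: last dimension concatenates the two halves;
# the kernel applies the activation to the first half and multiplies by the second
x = torch.randn(num_tokens, 2 * hidden_size, device="cuda", dtype=torch.float16)
y = flashinfer.activation.silu_and_mul(x)  # -> shape [num_tokens, hidden_size]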
3 changes: 3 additions & 0 deletions docs/api/python/norm.rst
@@ -11,3 +11,6 @@ Kernels for normalization layers.
:toctree: _generate

rmsnorm
fused_add_rmsnorm
gemma_rmsnorm
gemma_fused_add_rmsnorm
1 change: 1 addition & 0 deletions docs/api/python/page.rst
@@ -14,3 +14,4 @@ Append new K/V tensors to Paged KV-Cache
:toctree: ../../generated

append_paged_kv_cache
get_batch_indices_positions
6 changes: 6 additions & 0 deletions docs/api/python/rope.rst
@@ -14,3 +14,9 @@ Kernels for applying rotary embeddings.
apply_llama31_rope_inplace
apply_rope
apply_llama31_rope
apply_rope_pos_ids
apply_rope_pos_ids_inplace
apply_llama31_rope_pos_ids
apply_llama31_rope_pos_ids_inplace
apply_rope_with_cos_sin_cache
apply_rope_with_cos_sin_cache_inplace
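
For readers new to these kernels, a minimal reference of the underlying rotation (an assumption-laden sketch, not the library's implementation: it uses the common non-interleaved half-rotation layout and a fixed theta of 1e4, while the actual kernel variants expose layout, position-id, cos/sin-cache, and Llama-3.1-style frequency-scaling options):

import torch

def reference_rope(x, pos_ids, theta=1e4):
    # x: [nnz, num_heads, head_dim]; rotate the two halves of head_dim by
    # position-dependent angles (non-interleaved convention).
    d = x.shape[-1]
    freqs = 1.0 / (theta ** (torch.arange(0, d, 2, device=x.device, dtype=torch.float32) / d))
    angles = pos_ids.to(torch.float32)[:, None] * freqs[None, :]   # [nnz, d // 2]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]
    x1, x2 = x[..., : d // 2].float(), x[..., d // 2 :].float()
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).to(x.dtype)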
1 change: 1 addition & 0 deletions docs/index.rst
@@ -36,4 +36,5 @@ FlashInfer is a library for Large Language Models that provides high-performance
api/python/gemm
api/python/norm
api/python/rope
api/python/activation
api/python/quantization
6 changes: 6 additions & 0 deletions python/flashinfer/activation.py
@@ -111,6 +111,8 @@ def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
r"""Fused SiLU and Mul operation.

``silu(input[..., :hidden_size]) * input[..., hidden_size:]``

Parameters
----------
input: torch.Tensor
@@ -141,6 +143,8 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
r"""Fused GeLU Tanh and Mul operation.

``gelu(input[..., :hidden_size], approximate="tanh") * input[..., hidden_size:]``

Parameters
----------
input: torch.Tensor
@@ -171,6 +175,8 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
r"""Fused GeLU and Mul operation.

``gelu(input[..., :hidden_size]) * input[..., hidden_size:]``

Parameters
----------
input: torch.Tensor
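
As a cross-check of the formulas added in these docstrings, a plain PyTorch reference of the documented semantics (a sketch only, not the fused CUDA kernels; reading the GeLU-tanh variant as ``approximate="tanh"`` is an interpretation of the kernel name):

import torch
import torch.nn.functional as F

def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # silu(x[..., :d]) * x[..., d:], with d = x.shape[-1] // 2
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def ref_gelu_and_mul(x: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    # gelu(x[..., :d]) * x[..., d:]; pass approximate="tanh" for the tanh variant
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]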
20 changes: 18 additions & 2 deletions python/flashinfer/norm.py
@@ -50,6 +50,8 @@ def rmsnorm(
) -> torch.Tensor:
r"""Root mean square normalization.

``out[i] = (input[i] / RMS(input)) * weight[i]``

Parameters
----------
input: torch.Tensor
@@ -92,6 +94,12 @@ def fused_add_rmsnorm(
) -> None:
r"""Fused add root mean square normalization.

Step 1:
``residual[i] += input[i]``

Step 2:
``input[i] = (residual[i] / RMS(residual)) * weight[i]``

Parameters
----------
input: torch.Tensor
@@ -119,7 +127,9 @@ def gemma_rmsnorm(
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Gemma Root mean square normalization.
r"""Gemma-style root mean square normalization.

``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``

Parameters
----------
@@ -163,7 +173,13 @@ def _gemma_rmsnorm_fake(
def gemma_fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
r"""Gemma Fused add root mean square normalization.
r"""Gemma-style fused add root mean square normalization.

Step 1:
``residual[i] += input[i]``

Step 2:
``input[i] = (residual[i] / RMS(residual)) * (weight[i] + 1)``

Parameters
----------
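
A plain PyTorch reference of the normalization formulas documented above (a sketch only: it assumes the conventional placement of ``eps`` inside the root mean square, and it returns new tensors where the fused kernels update ``input``/``residual`` in place):

import torch

def ref_rmsnorm(x, weight, eps=1e-6, gemma=False):
    # out[i] = (x[i] / RMS(x)) * weight[i]; the Gemma variant uses (weight[i] + 1)
    rms = torch.sqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    w = weight.float() + 1.0 if gemma else weight.float()
    return (x.float() / rms * w).to(x.dtype)

def ref_fused_add_rmsnorm(x, residual, weight, eps=1e-6, gemma=False):
    # Step 1: residual += input; Step 2: input = rmsnorm(residual)
    residual = residual + x
    return ref_rmsnorm(residual, weight, eps, gemma), residual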
12 changes: 10 additions & 2 deletions python/flashinfer/page.py
Expand Up @@ -151,11 +151,15 @@ def get_batch_indices_positions(
>>> positions # the rightmost column index of each row
tensor([4, 3, 4, 2, 3, 4, 1, 2, 3, 4], device='cuda:0', dtype=torch.int32)

Notes
-----
Note
----
This function is similar to the `CSR2COO <https://docs.nvidia.com/cuda/cusparse/#csr2coo>`_
conversion in the cuSPARSE library, with the difference that we convert from a ragged
tensor (which does not require a column indices array) to COO format.

See Also
--------
append_paged_kv_cache
"""
batch_size = append_indptr.size(0) - 1
batch_indices = torch.empty((nnz,), device=append_indptr.device, dtype=torch.int32)
@@ -305,6 +309,10 @@ def append_paged_kv_cache(
The function assumes that the space for the appended k/v has already been allocated,
which means :attr:`kv_indices`, :attr:`kv_indptr`, and :attr:`kv_last_page_len` have
already incorporated the appended k/v.

See Also
--------
get_batch_indices_positions
"""
_check_kv_layout(kv_layout)
_append_paged_kv_cache_kernel(
Expand Down