FlashInfer: Kernel Library for LLM Serving


| Blog | Documentation | Slack | Discussion Forum |


FlashInfer is a library and kernel generator for Large Language Models that provides high-performance implementations of LLM GPU kernels such as FlashAttention, SparseAttention, PagedAttention, Sampling, and more. FlashInfer focuses on LLM serving and inference, and delivers state-of-the-art performance across diverse scenarios.

Check out our v0.2 release blog for new features!

The core features of FlashInfer include:

  1. Efficient Sparse/Dense Attention Kernels: Efficient single-request and batch attention for sparse (paged) and dense KV storage, with templates for both CUDA Cores and Tensor Cores (FA2 and FA3). Vector-sparse attention can achieve 90% of the bandwidth of dense kernels for the same problem size.
  2. Load-Balanced Scheduling: FlashInfer splits attention computation into separate plan and run stages; variable-length inputs are scheduled during the plan stage to alleviate load imbalance.
  3. Memory Efficiency: FlashInfer offers Cascade Attention for hierarchical KV-Cache, implements Head-Query fusion to accelerate Grouped-Query Attention, and provides efficient kernels for low-precision attention and for fused-RoPE attention on compressed KV-Cache.
  4. Customizable Attention: Bring your own attention variants through JIT compilation.
  5. CUDAGraph and torch.compile Compatibility: FlashInfer kernels can be captured by CUDA Graphs and torch.compile for low-latency inference.
  6. Efficient LLM-specific Operators: High-performance fused kernels for Top-P, Top-K, and Min-P sampling without the need for sorting.
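Cascade Attention, and splitting long KV sequences into chunks when scheduling, both depend on combining partial attention results computed over disjoint parts of the KV-Cache. The plain-Python sketch below illustrates that idea for a single query vector, assuming each partial result is kept as a normalized output vector plus the log-sum-exp (LSE) of its raw scores. The names `attention_state` and `merge_states` are hypothetical and are not FlashInfer's API; the library implements this on the GPU with its own kernels.

```python
import math
import random
from typing import List, Sequence, Tuple

# An attention "state": the normalized output vector and the log-sum-exp (LSE)
# of the raw attention scores it was computed from.
State = Tuple[List[float], float]


def attention_state(q: Sequence[float],
                    keys: Sequence[Sequence[float]],
                    values: Sequence[Sequence[float]]) -> State:
    """Softmax attention of one query over one KV chunk, returning (output, lse)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge_states(a: State, b: State) -> State:
    """Combine the states of two disjoint KV chunks into the state of their union."""
    (out_a, lse_a), (out_b, lse_b) = a, b
    m = max(lse_a, lse_b)
    # Each chunk's weight is its softmax denominator, rescaled by exp(-m).
    w_a, w_b = math.exp(lse_a - m), math.exp(lse_b - m)
    out = [(w_a * x + w_b * y) / (w_a + w_b) for x, y in zip(out_a, out_b)]
    return out, m + math.log(w_a + w_b)


if __name__ == "__main__":
    rng = random.Random(0)
    dim, kv_len = 4, 6
    q = [rng.gauss(0, 1) for _ in range(dim)]
    keys = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(kv_len)]
    values = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(kv_len)]

    full = attention_state(q, keys, values)
    # For example, a shared-prefix chunk and a per-request suffix chunk.
    merged = merge_states(attention_state(q, keys[:2], values[:2]),
                          attention_state(q, keys[2:], values[2:]))
    assert all(abs(x - y) < 1e-9 for x, y in zip(full[0], merged[0]))
    assert abs(full[1] - merged[1]) < 1e-9
    print("merged state matches full attention")
```

Because the merge only needs each chunk's output and LSE, a shared prefix can be attended once for many requests and then combined with each request's own suffix.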

FlashInfer supports PyTorch, TVM and C++ (header-only) APIs, and can be easily integrated into existing projects.

News

  • [Dec 16, 2024] Blog Post FlashInfer 0.2 - Efficient and Customizable Kernels for LLM Inference Serving
  • [Sept 2024] We've launched a Slack workspace for FlashInfer users and developers. Join us for timely support, discussions, updates and knowledge sharing!
  • [Jan 31, 2024] Blog Post Cascade Inference: Memory-Efficient Shared Prefix Batch Decoding
  • [Jan 31, 2024] Blog Post Accelerating Self-Attentions for LLM Serving with FlashInfer

Getting Started

Using our PyTorch API is the easiest way to get started:

Installation

We provide prebuilt wheels for Linux. You can install FlashInfer with the following command:

# For CUDA 12.4 & torch 2.4
pip install flashinfer -i https://flashinfer.ai/whl/cu124/torch2.4
# For other CUDA & torch versions, please check https://docs.flashinfer.ai/installation.html

We also offer nightly-built wheels to try the latest features from the main branch:

pip install flashinfer -i https://flashinfer.ai/whl/nightly/cu124/torch2.4

Alternatively, you can build FlashInfer from source:

git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
cd flashinfer
pip install -e . -v

By default, FlashInfer uses Just-In-Time (JIT) compilation for its kernels. To pre-compile essential kernels, set the environment variable FLASHINFER_ENABLE_AOT=1 before running the installation command:

FLASHINFER_ENABLE_AOT=1 pip install -e . -v

For more details, refer to the Install from Source documentation.

Trying it out

Below is a minimal example of using FlashInfer's single-request decode/append/prefill attention kernels:

import torch
import flashinfer

kv_len = 2048
num_kv_heads = 32
head_dim = 128

k = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)
v = torch.randn(kv_len, num_kv_heads, head_dim).half().to(0)

# decode attention

num_qo_heads = 32
q = torch.randn(num_qo_heads, head_dim).half().to(0)

o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="ROPE_LLAMA") # decode with LLaMA style RoPE on-the-fly

# append attention
append_qo_len = 128
q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask
o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="ROPE_LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask

# prefill attention
qo_len = 2048
q = torch.randn(qo_len, num_qo_heads, head_dim).half().to(0) # prefill attention
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=False) # prefill attention without RoPE on-the-fly, do not apply causal mask
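For checking results on small inputs, the plain-Python sketch below spells out the computation these single-request calls are expected to perform. It rests on assumptions: scores scaled by 1/sqrt(head_dim), no RoPE (the calls above without `pos_encoding_mode`), query heads mapped to KV heads in contiguous groups for grouped-query attention, and a causal mask aligned to the end of the KV sequence, consistent with the comment that the last `append_qo_len` KV entries are the new tokens. `reference_attention` is a hypothetical helper, not part of FlashInfer.

```python
import math
from typing import List

# Nested lists shaped [seq_len][num_heads][head_dim], mirroring the
# (len, heads, head_dim) tensor layout used in the example above.
Tensor3 = List[List[List[float]]]


def reference_attention(q: Tensor3, k: Tensor3, v: Tensor3,
                        causal: bool = False) -> Tensor3:
    """Naive single-request attention without RoPE (assumed semantics)."""
    qo_len, num_qo_heads, head_dim = len(q), len(q[0]), len(q[0][0])
    kv_len, num_kv_heads = len(k), len(k[0])
    assert num_qo_heads % num_kv_heads == 0, "qo heads must be a multiple of kv heads"
    assert not causal or qo_len <= kv_len, "causal mask needs qo_len <= kv_len"
    group_size = num_qo_heads // num_kv_heads  # GQA: query heads per KV head
    sm_scale = 1.0 / math.sqrt(head_dim)

    out: Tensor3 = []
    for i in range(qo_len):
        # Bottom-right aligned causal mask: the i-th query is the token at KV
        # position (kv_len - qo_len + i) and may attend to positions <= that.
        last = kv_len - qo_len + i if causal else kv_len - 1
        row = []
        for h in range(num_qo_heads):
            kv_h = h // group_size
            scores = [sm_scale * sum(a * b for a, b in zip(q[i][h], k[j][kv_h]))
                      for j in range(last + 1)]
            m = max(scores)  # subtract the max for numerical stability
            w = [math.exp(s - m) for s in scores]
            z = sum(w)
            row.append([sum(w[j] * v[j][kv_h][d] for j in range(last + 1)) / z
                        for d in range(head_dim)])
        out.append(row)
    return out
```

Decode corresponds to `qo_len = 1`: wrap the `(num_qo_heads, head_dim)` query in a one-element list. To compare with a FlashInfer result on a small problem, convert tensors with `.float().cpu().tolist()` and check agreement within fp16 tolerance.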

Check out the documentation for the usage of batch decode/append/prefill kernels and shared-prefix cascading kernels.
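The batch kernels work on a paged KV-Cache described in a CSR-like layout: an indptr array of length `batch_size + 1`, a flat list of page indices, and the number of valid entries in each request's last page. The plain-Python sketch below shows how those arrays relate to per-request KV lengths. `build_paged_kv_layout` and `kv_lens_from_layout` are hypothetical helpers, and the sequential page allocator is an assumption (a serving engine manages its own pages); check the documentation for the exact parameter names and dtypes the batch wrappers expect.

```python
from typing import List, Tuple


def build_paged_kv_layout(kv_lens: List[int],
                          page_size: int) -> Tuple[List[int], List[int], List[int]]:
    """Return (indptr, indices, last_page_len) for requests with the given KV lengths."""
    indptr = [0]
    indices: List[int] = []
    last_page_len: List[int] = []
    next_page = 0  # naive allocator: hand out pages sequentially
    for kv_len in kv_lens:
        assert kv_len > 0, "each request needs at least one KV entry"
        num_pages = (kv_len + page_size - 1) // page_size  # ceiling division
        indices.extend(range(next_page, next_page + num_pages))
        next_page += num_pages
        indptr.append(indptr[-1] + num_pages)
        # Entries used in the final page, always in [1, page_size].
        last_page_len.append(kv_len - (num_pages - 1) * page_size)
    return indptr, indices, last_page_len


def kv_lens_from_layout(indptr: List[int], last_page_len: List[int],
                        page_size: int) -> List[int]:
    """Invert the layout: full pages before the last one, plus the last page's entries."""
    return [(indptr[i + 1] - indptr[i] - 1) * page_size + last_page_len[i]
            for i in range(len(last_page_len))]


indptr, indices, last_page_len = build_paged_kv_layout([5, 8, 1], page_size=4)
# indptr == [0, 2, 4, 5], indices == [0, 1, 2, 3, 4], last_page_len == [1, 4, 1]
```

Request `i` owns pages `indices[indptr[i]:indptr[i + 1]]`, so variable-length requests share one flat page pool without padding.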

Run Benchmarks

We profile FlashInfer kernel performance with nvbench. You can compile and run the benchmarks with the following commands:

mkdir build
cp cmake/config.cmake build # you can modify the config.cmake to enable/disable benchmarks and change CUDA architectures
cd build
cmake ..
make -j12

Run ./bench_{single/batch}_{prefill/decode} to benchmark performance (e.g., ./bench_single_prefill for single-request prefill attention). Running any of these binaries with --help lists the available options.

C++ API and TVM Bindings

FlashInfer also provides a C++ API and TVM bindings; please refer to the documentation for more details.

Adoption

We are thrilled to share that FlashInfer is being adopted by many cutting-edge projects, including but not limited to:

Acknowledgement

FlashInfer is inspired by the FlashAttention 1 & 2, vLLM, Stream-K, CUTLASS, and AITemplate projects.

Languages

  • Cuda 59.2%
  • Python 37.3%
  • C++ 1.8%
  • CMake 1.3%
  • Other 0.4%