Add README for fbgemm genAI op (#2931)
Summary: Pull Request resolved: #2931 Differential Revision: D60360110
1 parent f117548, commit 91613a0
Showing 1 changed file with 113 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# FBGEMM GenAI Operators | ||
|
||
# **1. Overview** | ||
|
||
FBGEMM FP8 rowwise quantization kernels have been officially adopted in the [Llama3.1 release](https://fb.workplace.com/groups/221503021668016/permalink/1900301927121442/). FP8 has been applied to the Llama3 models with 8B, 70B, and 405B parameters. Notably, for the 405B model, FP8 enables inference on a single node and achieves a 2x throughput improvement over the BF16 baseline, which runs on two nodes with pipeline parallelism. Externally, the kernels are mentioned in the [Llama3 paper](https://ai.meta.com/research/publications/the-llama-3-herd-of-models/) and repo, [HuggingFace](https://huggingface.co/docs/transformers/main/quantization/fbgemm_fp8), [vLLM](https://blog.vllm.ai/2024/07/23/llama31.html), and [TensorRT-LLM](https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms/). | ||
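
The single-node claim for the 405B model can be sanity-checked with simple arithmetic. The sketch below counts only weight storage (2 bytes per BF16 value, 1 byte per FP8 value) and ignores the KV cache, activations, and any layers left unquantized, so the FP8 figure is a lower bound. The node size of 8 GPUs with 80 GB each is an assumption for illustration and is not stated in this README.

```python
# Back-of-the-envelope weight memory for a 405B-parameter model.
# Assumption (not from the README): one node = 8 GPUs x 80 GB of HBM each.
PARAMS = 405e9
BYTES_PER_PARAM = {"bf16": 2, "fp8": 1}
NODE_HBM_GB = 8 * 80  # assumed node capacity in GB


def weight_gb(dtype: str) -> float:
    """Weight footprint in GB (1 GB = 1e9 bytes); KV cache and activations are ignored."""
    return PARAMS * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    gb = weight_gb(dtype)
    verdict = "fits in" if gb <= NODE_HBM_GB else "exceeds"
    print(f"{dtype}: {gb:.0f} GB of weights {verdict} one {NODE_HBM_GB} GB node")
```

Under these assumptions the BF16 weights alone exceed one node, while the FP8 weights leave headroom for the KV cache.
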
|
||
FBGEMM GenAI FP8 supports a variety of configurations: | ||
|
||
* GEMM Operators: {hipBLASLt, CK, Triton} x {BF16, FP8} x {tensor-wise, row-wise, block-wise} x {NVIDIA H100, AMD MI300X} | ||
* High/Low Precision Conversion Kernels: (FP32 / BF16 <-> FP8) with scaling options {tensor-wise, row-wise, block-wise}, across hardware platforms {NVIDIA H100, AMD MI300X}, implemented in {Triton, CUDA/HIP}. | ||
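
The three scaling options in the list above differ only in how many values share one scale factor. The following plain-Python sketch illustrates that difference; it is not FBGEMM code. The helper names are hypothetical, and the constant 448 assumes the e4m3fn FP8 format (other FP8 variants have a different maximum).

```python
FP8_E4M3_MAX = 448.0  # largest finite e4m3fn value (assumed target format)


def tensor_scale(m):
    """Tensor-wise: a single scale shared by the whole matrix."""
    return max(abs(v) for row in m for v in row) / FP8_E4M3_MAX


def row_scales(m):
    """Row-wise: one scale per row (per token or per output channel)."""
    return [max(abs(v) for v in row) / FP8_E4M3_MAX for row in m]


def block_scales(m, bm, bn):
    """Block-wise: one scale per (bm x bn) tile; assumes the dimensions divide evenly."""
    return [
        [
            max(abs(m[i][j]) for i in range(r, r + bm) for j in range(c, c + bn))
            / FP8_E4M3_MAX
            for c in range(0, len(m[0]), bn)
        ]
        for r in range(0, len(m), bm)
    ]


# Row 1 contains a large outlier; row 0 holds only small values.
m = [
    [0.5, -1.0, 2.0, 0.25],
    [448.0, 1.0, -3.0, 0.5],
]
print("tensor-wise:", tensor_scale(m))
print("row-wise:   ", row_scales(m))
print("block-wise: ", block_scales(m, 1, 2))
```

With tensor-wise scaling the outlier in row 1 dictates the scale for row 0 as well, which wastes most of the FP8 range on the small values. Row-wise and block-wise scaling confine the outlier's effect to its own row or tile.
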
|
||
Besides FP8 support, FBGEMM GenAI operators also support: | ||
|
||
* Customized AllReduce communication (reduced latency for small message sizes). | ||
* GQA (grouped-query attention): optimized specifically for decoding, as detailed in the PyTorch blog post on INT4 decoding (https://pytorch.org/blog/int4-decoding/). | ||
|
||
## **1.1 FP8 core API functions** | ||
|
||
```python | ||
# Row-wise (per output channel) quantization of the weight from BF16 to FP8 | ||
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w) | ||
# Row-wise (per token) quantization of the activation from BF16 to FP8 | ||
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row( | ||
    x, num_tokens, activation_scale_ub | ||
) | ||
# Row-wise scaled GEMM with FP8 inputs and BF16 output | ||
y = torch.ops.fbgemm.f8f8bf16_rowwise( | ||
    xq, | ||
    wq, | ||
    x_scale, | ||
    w_scale, | ||
    use_fast_accum=True, | ||
) | ||
``` | ||
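
To make the arithmetic behind these three calls concrete, here is a minimal pure-Python reference sketch. It is not the FBGEMM implementation: FP8 is simulated by rounding to 3 mantissa bits, the maximum of 448 assumes the e4m3fn format, and the handling of the upper bound (clamping the per-row maximum before computing the scale) is an assumption about the semantics of `activation_scale_ub`. The function names are hypothetical, and `use_fast_accum` is not modeled.

```python
import math

FP8_MAX = 448.0  # largest finite e4m3fn value (assumed target format)


def round_to_e4m3(v):
    """Clamp to the FP8 range and round to 3 mantissa bits (simulated FP8)."""
    v = max(-FP8_MAX, min(FP8_MAX, v))
    if v == 0.0:
        return 0.0
    _, e = math.frexp(abs(v))   # abs(v) = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)        # -6 is the smallest normal exponent
    step = 2.0 ** (exp - 3)     # spacing between representable values
    return round(v / step) * step


def quantize_per_row_ref(mat, scale_ub=None):
    """Return (q, scales) with mat[i][j] approximately q[i][j] * scales[i]."""
    q, scales = [], []
    for row in mat:
        row_max = max(abs(v) for v in row)
        if scale_ub is not None:
            row_max = min(row_max, scale_ub)  # assumed upper-bound semantics
        scale = max(row_max, 1e-12) / FP8_MAX
        q.append([round_to_e4m3(v / scale) for v in row])
        scales.append(scale)
    return q, scales


def rowwise_gemm_ref(xq, wq, x_scale, w_scale):
    """y[i][j] = (sum_k xq[i][k] * wq[j][k]) * x_scale[i] * w_scale[j]."""
    return [
        [sum(a * b for a, b in zip(xr, wr)) * xs * ws for wr, ws in zip(wq, w_scale)]
        for xr, xs in zip(xq, x_scale)
    ]


x = [[0.1, -0.7, 1.5], [20.0, 3.0, -8.0]]      # activations, shape [M, K]
w = [[0.02, 0.5, -0.25], [1.0, -2.0, 0.75]]    # weights, shape [N, K]
xq, x_scale = quantize_per_row_ref(x)
wq, w_scale = quantize_per_row_ref(w)
y = rowwise_gemm_ref(xq, wq, x_scale, w_scale)
exact = [[sum(a * b for a, b in zip(xr, wr)) for wr in w] for xr in x]
print("quantized GEMM:", y)
print("exact GEMM:    ", exact)
```

Because the scales factor out of the inner product, the expensive accumulation runs entirely on the quantized values and the two scale vectors are applied once per output element.
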
|
||
## **1.2 How to install** | ||
|
||
```bash | ||
# Full FBGEMM library | ||
pip install fbgemm-gpu==0.8.0 | ||
pip install fbgemm-gpu==0.8.0 --index-url https://download.pytorch.org/whl/cu121 | ||
# FBGEMM library with GenAI operator only | ||
pip install fbgemm-gpu-genai | ||
pip install fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/cu121 | ||
``` | ||
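
After installing, a quick check from Python can confirm that the package is visible in the current environment. This is a small sketch, not an official verification step: the module path `fbgemm_gpu.experimental.gen_ai` is inferred from the repository layout (`fbgemm_gpu/experimental/gen_ai`) and the helper name is hypothetical.

```python
import importlib.util


def genai_ops_available() -> bool:
    """Return True if the fbgemm_gpu GenAI package can be located."""
    try:
        # find_spec imports the parent packages of a dotted name.
        return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
    except ImportError:
        # Raised when `fbgemm_gpu` itself is missing or fails to import.
        return False


print("FBGEMM GenAI package found:", genai_ops_available())
```
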
|
||
# **2. External Coverage** | ||
|
||
## **2.1 Llama3 Paper** | ||
|
||
https://arxiv.org/pdf/2407.21783 | ||
|
||
![](https://lh7-us.googleusercontent.com/jNezF7HQlegHlgTxKnqaZaiOqZ7Kn7A5WlPYo2c9cjTLhZJIhSR3YRHUa0msQuCu04d_tElMPyjzyRwg-Drrsgqt6Swy4wyAFrm-NIGoIn6fJyDUdj3F4uGhahU0ciAyM7lw7d-gMx8jlYLgR4xWB58) | ||
|
||
> We perform experiments leveraging the native FP8 support of H100 GPUs to perform low-precision inference. To enable low-precision inference, we apply FP8 quantization to most matrix multiplications inside the model. In particular, we quantize most parameters and activations in the feedforward network layers in the model, which account for roughly 50% of the inference compute time. We do not quantize parameters in the self-attention layers of the model. We leverage dynamic scaling factors for better accuracy (Xiao et al., 2024b), optimizing our CUDA kernels [15] to reduce the overhead of calculating the scales. | ||
> [15] Our FP8 kernels are available at https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai. We provide usage examples at https://github.com/meta-llama/llama-agentic-system. | ||
## **2.2 Llama3 Repo** | ||
|
||
Llama Toolchain: | ||
|
||
https://github.com/meta-llama/llama-toolchain/tree/main | ||
|
||
Llama Agentic System: | ||
|
||
https://github.com/meta-llama/llama-agentic-system/tree/main?tab=readme-ov-file | ||
|
||
llm_inference: | ||
|
||
https://github.com/fairinternal/llm_inference | ||
|
||
![](https://lh7-us.googleusercontent.com/-uG9EByUNO8DS97DkR_-3APsYpcUtmYHUdlWKyonCed_8kW8Ljd2-g92NJZeAgyGGoGDd9G800UvYkfX64LvYvWgtGEVzCBkYV-RtXPiCvpigiqxsJBUNoEWicLz3ABr9f3_T8UtKSi14AcyBWWeEc8) | ||
|
||
## **2.3 HuggingFace** | ||
|
||
https://huggingface.co/docs/transformers/main/quantization/fbgemm_fp8 | ||
|
||
![](https://lh7-us.googleusercontent.com/HSWwsJHLlEriaShYShKKNNchjeyL173Dgt3jtjiY8desbdivQjI7k1NqBKwvDDxEg5u6A5xNKgX8j_4RvC0mNO3-D-3Pr2MSTGVUeJ8Y15c1llLEwqRXuj2vrNAZGGpyRwdq4xJ-esl8ClzxxHtt7Ro) | ||
|
||
> With FBGEMM FP8 quantization method, you can quantize your model in FP8 (W8A8): | ||
> * the weights will be quantized in 8bit (FP8) per channel | ||
> * the activation will be quantized in 8bit (FP8) per token | ||
> It relies on the [FBGEMM](https://github.com/pytorch/FBGEMM) library which provides efficient low-precision general matrix multiplication for small batch sizes and support for accuracy-loss minimizing techniques such as row-wise quantization and outlier-aware quantization. | ||
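
The W8A8 scheme quoted above is exposed in Transformers through a quantization config object. The sketch below follows the pattern in the linked HuggingFace documentation; the wrapper function name is hypothetical, and actually running it requires a supported GPU, the `transformers` and `fbgemm-gpu` packages, and access to the chosen model.

```python
def load_fbgemm_fp8_model(model_name: str):
    """Load a causal LM whose linear layers are quantized to FP8 via FBGEMM."""
    # Imported lazily so that defining this helper does not require transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config

    quantization_config = FbgemmFp8Config()
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
```

A call such as `load_fbgemm_fp8_model("meta-llama/Meta-Llama-3-8B")` would quantize the weights at load time; the model identifier here is only an example.
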
## **2.4 vLLM** | ||
|
||
https://blog.vllm.ai/2024/07/23/llama31.html | ||
|
||
![](https://lh7-us.googleusercontent.com/Z9QHtQtjO_1rdQkPau_CxtXY27cO4mffTfWrnaVkRofc_AsyHTukqh4JMUweWDc-0x9AETVTbJZwIPN_fbVA7akMZJb-vNImECqPe-miLrtg4qK1VK608gkMZ0Lz15Px5ew_pD58-5NW0IoWE7-b20w) | ||
|
||
> Currently, vLLM supports the official Meta Llama 3.1 405B FP8 model quantized via **FBGEMM** by leveraging per-channel quantization in the MLP layer. In particular, each channel of the up/gate/down projections are quantized and multiplied by a static scaling factor. Combined with skipping quantization for the first and the last layer, and a static upper bound, this approach has minimal impact on the model’s accuracy. | ||
|
||
## **2.5 TensorRT-LLM** | ||
|
||
https://github.com/NVIDIA/TensorRT-LLM/blame/5fa9436e17c2f9aeace070f49aa645d2577f676b/cpp/tensorrt_llm/common/quantTypeUtils.cuh#L47 | ||
|
||
![](https://lh7-us.googleusercontent.com/Kvwseo5yqT6gwM4WK-VQ7oqF47xDSqko1HL7X326TvbQRU-W0Evf0mhoB5wKVw4N3ZO8x7ZpGKzOcFNvCgCZE-SwRyKDu9FjjQQXpCeoOtgGySfpBnPynDeOodfyOH0glBU6uDkavGo170iu32CK-vY) | ||
|
||
// Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720 | ||
|
||
![](7c7e600d-7c40-4ab4-9928-42f7a75d339aScreenshot_2024-08-01_at_11.26.31_PM.png) | ||
|
||
https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms/ | ||
|
||
> For the Llama 3.1-405B model, TensorRT-LLM has added support for FP8 quantization at a row-wise granularity level. This involves calculating a static scaling factor for each output weight channel (before execution) and a dynamic scaling factor for each token (during execution) to preserve maximum accuracy. | ||
> During the TensorRT engine build process, some complex layer fusions cannot be automatically discovered. TensorRT-LLM optimizes these using plugins that are explicitly inserted into the network graph definition at compile time to replace user-defined kernels such as the matrix multiplications from **FBGEMM** for the Llama 3.1 models. | ||
> For ease of use and deployment, TensorRT-Model-Optimizer and TensorRT-LLM optimizations are bundled together into NVIDIA NIM inference microservices. |