Add README for fbgemm genAI op (#2931)
Summary: Pull Request resolved: #2931

Differential Revision: D60360110
jianyuh authored and facebook-github-bot committed Aug 2, 2024
1 parent f117548 commit 91613a0
fbgemm_gpu/experimental/gen_ai/README.md (113 additions, 0 deletions)

# FBGEMM GenAI Operators

# **1. Overview**

FBGEMM FP8 rowwise quantization kernels have been officially adopted in the [Llama3.1 release](https://fb.workplace.com/groups/221503021668016/permalink/1900301927121442/). FP8 has been applied across the Llama3 models at the 8B, 70B, and 405B scales. Notably, for the 405B model, FP8 enables inference on a single node, achieving a 2x throughput improvement over the BF16 baseline running on two nodes with pipeline parallelism. Externally, it has been mentioned in the [Llama3 paper](https://ai.meta.com/research/publications/the-llama-3-herd-of-models/) and repo, [HuggingFace](https://huggingface.co/docs/transformers/main/quantization/fbgemm_fp8), [vLLM](https://blog.vllm.ai/2024/07/23/llama31.html), and [TensorRT-LLM](https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms/).

FBGEMM GenAI FP8 supports a variety of configurations:

* GEMM operators: {hipBLASLt, CK, Triton} x {BF16, FP8} x {tensor-wise, row-wise, block-wise} scaling x {Nvidia H100, AMD MI300x}.
* High/low-precision conversion kernels (FP32/BF16 <-> FP8) with {tensor-wise, row-wise, block-wise} scaling options, across {Nvidia H100, AMD MI300x} hardware platforms and {Triton, CUDA/HIP} programming options.

Beyond FP8 support, FBGEMM GenAI operators also provide:

* Customized AllReduce communication kernels (reduced latency for small message sizes).
* GQA attention optimized specifically for decoding, as detailed in [PyTorch's blog on INT4 decoding](https://pytorch.org/blog/int4-decoding/); see the reference sketch after this list.
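
For readers unfamiliar with grouped-query attention, the sketch below is a plain-PyTorch reference of the computation for a single decode step. It is not the FBGEMM kernel API, only an illustration of what the optimized decode kernels compute; the tensor layouts are illustrative assumptions.

```python
import torch

def gqa_decode_reference(q, k_cache, v_cache):
    """Reference grouped-query attention for one decode step (plain PyTorch).

    q:        [B, H_q, D]       one new query token per sequence
    k_cache:  [B, T, H_kv, D]   cached keys
    v_cache:  [B, T, H_kv, D]   cached values
    H_q must be a multiple of H_kv; each KV head serves H_q // H_kv query heads.
    """
    B, H_q, D = q.shape
    _, T, H_kv, _ = k_cache.shape
    group = H_q // H_kv

    # Expand KV heads so each query head has a matching KV head.
    k = k_cache.repeat_interleave(group, dim=2)  # [B, T, H_q, D]
    v = v_cache.repeat_interleave(group, dim=2)  # [B, T, H_q, D]

    # Scores of the single query token against all cached positions.
    scores = torch.einsum("bhd,bthd->bht", q, k) / (D ** 0.5)
    probs = scores.softmax(dim=-1)
    return torch.einsum("bht,bthd->bhd", probs, v)  # [B, H_q, D]
```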

## **1.1 FP8 core API functions**

```python
# Rowwise (channel-wise) quantize the weight from BF16 to FP8
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
# Rowwise (token-wise) quantize the activation from BF16 to FP8
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
    x, num_tokens, activation_scale_ub
)
# Rowwise-scaled GEMM with FP8 inputs and BF16 output
y = torch.ops.fbgemm.f8f8bf16_rowwise(
    xq,
    wq,
    x_scale,
    w_scale,
    use_fast_accum=True,
)
```
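
As a point of reference, here is a minimal end-to-end sketch combining the three calls above. The problem sizes are illustrative, and the optional `num_tokens` / `activation_scale_ub` arguments are omitted for brevity; treat this as a hedged example rather than an official one.

```python
import torch
# The fbgemm_gpu GenAI package must be installed and imported so that
# these ops are registered under torch.ops.fbgemm (see Section 1.2).

M, N, K = 16, 4096, 4096  # illustrative problem sizes
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")  # activations
w = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")  # weights

# Quantize the weight per output channel and the activation per token to FP8.
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x)

# FP8 rowwise GEMM with BF16 output; y approximates x @ w.T.
y = torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale, use_fast_accum=True)
print(y.shape, y.dtype)  # expected: torch.Size([16, 4096]) torch.bfloat16
```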

## **1.2 How to install**

```bash
# Full FBGEMM library
pip install fbgemm-gpu==0.8.0
# or, pinned to the CUDA 12.1 builds on the PyTorch index:
pip install fbgemm-gpu==0.8.0 --index-url https://download.pytorch.org/whl/cu121

# FBGEMM library with the GenAI operators only
pip install fbgemm-gpu-genai
pip install fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/cu121
```
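
After installation, a quick way to confirm that the GenAI ops are registered is shown below. The `fbgemm_gpu.experimental.gen_ai` import path is an assumption based on this repository's layout; adjust it if the packaging differs.

```python
import torch

# Importing the GenAI module registers the ops under torch.ops.fbgemm.
# The exact import path is an assumption based on this repo's layout.
import fbgemm_gpu.experimental.gen_ai  # noqa: F401

print(hasattr(torch.ops.fbgemm, "f8f8bf16_rowwise"))  # expected: True
```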

# **2. External Coverage**

## **2.1 Llama3 Paper**

https://arxiv.org/pdf/2407.21783


> We perform experiments leveraging the native FP8 support of H100 GPUs to perform low-precision inference. To enable low-precision inference, we apply FP8 quantization to most matrix multiplications inside the model. In particular, we quantize most parameters and activations in the feedforward network layers in the model, which account for roughly 50% of the inference compute time. We do not quantize parameters in the self-attention layers of the model. We leverage dynamic scaling factors for better accuracy (Xiao et al., 2024b), optimizing our CUDA kernels to reduce the overhead of calculating the scales.
> Our FP8 kernels are available at https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu/experimental/gen_ai. We provide usage examples at https://github.com/meta-llama/llama-agentic-system.
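
To make the recipe in the quote concrete (quantize the feedforward matmuls, keep attention in BF16, bound the dynamic activation scale), here is a hedged sketch built only from the ops in Section 1.1. The class name, the `None` placeholder for `num_tokens`, and the scale-bound value and dtype are illustrative assumptions, not FBGEMM APIs.

```python
import torch


class FP8RowwiseLinear(torch.nn.Module):
    """Illustrative FFN linear using FBGEMM FP8 rowwise quantization.

    The weight is quantized once per output channel; activations are quantized
    per token at call time, with an upper bound on the dynamic scale.
    Attention layers would keep their BF16 linears, as in the quote above.
    """

    def __init__(self, weight_bf16: torch.Tensor, scale_ub: float = 1200.0):
        super().__init__()
        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(weight_bf16)
        self.register_buffer("wq", wq)
        self.register_buffer("w_scale", w_scale)
        self.register_buffer(
            "scale_ub",
            torch.tensor([scale_ub], dtype=torch.float32, device=weight_bf16.device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dynamic per-token activation scales, clamped by the upper bound.
        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(x, None, self.scale_ub)
        return torch.ops.fbgemm.f8f8bf16_rowwise(
            xq, self.wq, x_scale, self.w_scale, use_fast_accum=True
        )
```
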
## **2.2 Llama3 Repo**

Llama Toolchain:

https://github.com/meta-llama/llama-toolchain/tree/main

Llama Agentic System:

https://github.com/meta-llama/llama-agentic-system/tree/main

llm_inference:

https://github.com/fairinternal/llm_inference


## **2.3 HuggingFace**

https://huggingface.co/docs/transformers/main/quantization/fbgemm_fp8


> With FBGEMM FP8 quantization method, you can quantize your model in FP8 (W8A8):
> * the weights will be quantized in 8bit (FP8) per channel
> * the activation will be quantized in 8bit (FP8) per token
> It relies on the [FBGEMM](https://github.com/pytorch/FBGEMM) library which provides efficient low-precision general matrix multiplication for small batch sizes and support for accuracy-loss minimizing techniques such as row-wise quantization and outlier-aware quantization.
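
For reference, the Transformers integration linked above is typically invoked along the lines below. This is a hedged sketch: `FbgemmFp8Config` and the argument names follow the linked docs, and the checkpoint name is illustrative; consult the page above for the authoritative example.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config

model_name = "meta-llama/Meta-Llama-3-8B"  # illustrative checkpoint
# W8A8: FP8 weights per channel, FP8 activations per token
quantization_config = FbgemmFp8Config()

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
```
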
## **2.4 vLLM**

https://blog.vllm.ai/2024/07/23/llama31.html


> Currently, vLLM supports the official Meta Llama 3.1 405B FP8 model quantized via **FBGEMM** by leveraging per-channel quantization in the MLP layer. In particular, each channel of the up/gate/down projections are quantized and multiplied by a static scaling factor. Combined with skipping quantization for the first and the last layer, and a static upper bound, this approach has minimal impact on the model’s accuracy.

## **2.5 TensorRT-LLM**

https://github.com/NVIDIA/TensorRT-LLM/blame/5fa9436e17c2f9aeace070f49aa645d2577f676b/cpp/tensorrt_llm/common/quantTypeUtils.cuh#L47


TensorRT-LLM's FP8 quantization utilities reference FBGEMM's quantize kernels directly, e.g. `// Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720` in the file linked above.


https://developer.nvidia.com/blog/supercharging-llama-3-1-across-nvidia-platforms/

> For the Llama 3.1-405B model, TensorRT-LLM has added support for FP8 quantization at a row-wise granularity level. This involves calculating a static scaling factor for each output weight channel (before execution) and a dynamic scaling factor for each token (during execution) to preserve maximum accuracy.
> During the TensorRT engine build process, some complex layer fusions cannot be automatically discovered. TensorRT-LLM optimizes these using plugins that are explicitly inserted into the network graph definition at compile time to replace user-defined kernels such as the matrix multiplications from **FBGEMM** for the Llama 3.1 models.
> For ease of use and deployment, TensorRT-Model-Optimizer and TensorRT-LLM optimizations are bundled together into NVIDIA NIM inference microservices.
