diff --git a/.gitignore b/.gitignore
index 8b4455fc2..e643ae280 100644
--- a/.gitignore
+++ b/.gitignore
@@ -10,4 +10,7 @@ site/
# Build
build/
-dist/
\ No newline at end of file
+dist/
+
+# Benchmark images
+benchmark/visualizations
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 2ea925a59..8c40c5781 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -55,7 +55,7 @@ The `/benchmark` directory contains benchmarking scripts for the individual kern
- Existing entries that are the same (based on `kernel_name`, `kernel_provider`, `kernel_operation_mode`, `metric_name`, `x_name`, `x_value`, `extra_benchmark_config_str`, and `gpu_name`) will not be overwritten.
2. Run `make run-benchmarks OVERWRITE=1` to overwrite any existing entries that have the same configuration.
3. Run `python benchmark/scripts/benchmark_{kernel_name}.py` to run an individual benchmark.
-4. You can use the `benchmark/benchmarks_visualizer.ipynb` notebook as an example to load the CSV and perform data visualization/analysis.
+4. You can use the `benchmark/benchmarks_visualizer.py` script to generate visualizations from the CSV; the resulting plots are saved to the `benchmark/visualizations` directory (note: this directory is not tracked by git).
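+   For example, running `python benchmarks_visualizer.py --kernel-name kl_div --metric-name speed --kernel-operation-mode full` from the `benchmark/` directory (flag values here are illustrative) produces `benchmark/visualizations/kl_div_speed.png`.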
## Submit PR
Fork the repo, copy and paste the successful test logs in the PR and submit the PR followed by the PR template (**[example PR](https://github.com/linkedin/Liger-Kernel/pull/21)**).
diff --git a/README.md b/README.md
index 5e092f0cc..6d11f99c8 100644
--- a/README.md
+++ b/README.md
@@ -40,18 +40,19 @@
-[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Structure](#structure) | [Contributing](#contributing) | [Contact](#contact)
+[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Structure](#structure) | [Contributing](#contributing) | [Acknowledgement](#acknowledgement)
Latest News 🔥
+ - [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
- [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
- [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)
-**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
+**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
## Supercharge Your Model with Liger Kernel
@@ -131,7 +132,7 @@ pip install -e .
```
## Getting Started
-There are a couple ways to apply Liger kernels, depending on the level of customization required.
+There are a couple of ways to apply Liger kernels, depending on the level of customization required.
### 1. Use AutoLigerKernelForCausalLM
@@ -241,6 +242,7 @@ loss.backward()
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
+| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
@@ -254,7 +256,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
-
+- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward pass into a single Triton kernel, with reduction done outside the kernel. It achieves ~1.5X speedup and ~15% memory reduction for 128K vocab size (see the usage sketch below).
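+
+A minimal usage sketch (shapes and the `reduction` value are illustrative, and we assume `LigerKLDIVLoss` mirrors the `torch.nn.KLDivLoss` interface of log-probabilities as input and probabilities as target):
+
+```python
+import torch
+from liger_kernel.transformers import LigerKLDIVLoss
+
+kl_loss = LigerKLDIVLoss(reduction="batchmean")
+
+B, T, V = 2, 512, 32000  # illustrative shapes
+student_logits = torch.randn(B * T, V, device="cuda", requires_grad=True)
+teacher_logits = torch.randn(B * T, V, device="cuda")
+
+# As with torch.nn.KLDivLoss, the input is log-probabilities and the target is probabilities
+loss = kl_loss(
+    torch.log_softmax(student_logits, dim=-1),
+    torch.softmax(teacher_logits, dim=-1),
+)
+loss.backward()
+```
+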
### Experimental Kernels
@@ -290,12 +292,30 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil
## Acknowledgement
+
+### Design
+
- [@claire_yishan](https://twitter.com/claire_yishan) for the LOGO design
-- [flash-attn](https://github.com/Dao-AILab/flash-attention) and [Unsloth](https://github.com/unslothai/unsloth) for inspiration in Triton kernels for training
-- [tiny shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) by Andrej Karpathy for convergence testing
-- [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy) for lm_head + cross entropy inspiration
- [Wave Snippets](https://www.wavesnippets.com/) for generating the animated code snippets
+### Code
+
+We referenced or used the following projects:
+
+
+
+| # | Project | Description | Location | License |
+|---|----------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------|
+| 1 | [Unsloth](https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43) | `calculate_settings` to determine block size and warp; We reuse it for Norm and MLP | [Liger Kernel Utils](https://github.com/linkedin/Liger-Kernel/blob/e249eee723978bf8610ff1ea2297d048a2417e20/src/liger_kernel/ops/utils.py#L23) | [Apache](https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/LICENSE) |
+| 2 | [Unsloth](https://github.com/unslothai/unsloth/blob/976d11a10d54383aeb7a692c69e01151a20bfd72/unsloth/kernels/rms_layernorm.py#L48) | We modified and added dW calculation on top of Unsloth implementation | [Liger Kernel RMS Norm](https://github.com/linkedin/Liger-Kernel/blob/e249eee723978bf8610ff1ea2297d048a2417e20/src/liger_kernel/ops/rms_norm.py#L50) | [Apache](https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/LICENSE) |
+| 3 | [Triton tutorial](https://triton-lang.org/main/index.html) | We adapted code from the Triton tutorials | [Liger Kernel RMS Norm](https://github.com/linkedin/Liger-Kernel/blob/e249eee723978bf8610ff1ea2297d048a2417e20/src/liger_kernel/ops/rms_norm.py#L50) | [MIT](https://github.com/triton-lang/triton/blob/main/LICENSE) |
+| 4 | [tiny shakespeare dataset](https://huggingface.co/datasets/karpathy/tiny_shakespeare) | We use the tiny shakespeare dataset to conduct convergence tests on mini models | [Liger Kernel Convergence](https://github.com/linkedin/Liger-Kernel/tree/main/test/convergence) | N/A |
+| 5 | [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy) | We use the idea of gradient-in-forward and chunking | [Liger Kernel Linear Cross Entropy](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/fused_linear_cross_entropy.py) | [MIT](https://github.com/mgmalek/efficient_cross_entropy/blob/main/LICENSE) |
+| 6 | [Flash attn](https://github.com/Dao-AILab/flash-attention) | We take many optimization ideas from the work, such as tiling and recomputation | | [BSD](https://github.com/Dao-AILab/flash-attention/blob/main/LICENSE) |
+| 7 | [AutoAWQ](https://github.com/casper-hansen/AutoAWQ) | We reference the design of automodel | [Liger Kernel Auto Model](https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/auto_model.py) | [MIT](https://github.com/casper-hansen/AutoAWQ/blob/main/LICENSE) |
+| 8 | [llm.c](https://github.com/karpathy/llm.c) | We reference the design of end-to-end testing | [Liger Kernel Convergence Tests](https://github.com/linkedin/Liger-Kernel/tree/main/test/convergence) | [MIT](https://github.com/karpathy/llm.c/blob/master/LICENSE) |
+
+Many thanks to the contributors to these projects for their invaluable work that helped make Liger possible.
## License
diff --git a/benchmark/benchmarks_visualizer.ipynb b/benchmark/benchmarks_visualizer.ipynb
deleted file mode 100644
index 7ad9461fd..000000000
--- a/benchmark/benchmarks_visualizer.ipynb
+++ /dev/null
@@ -1,132 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7b9a1150",
- "metadata": {
- "metadata": {}
- },
- "outputs": [],
- "source": [
- "import pandas as pd\n",
- "import matplotlib.pyplot as plt\n",
- "import seaborn as sns\n",
- "import json"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1baa5a05",
- "metadata": {
- "metadata": {}
- },
- "outputs": [],
- "source": [
- "data = pd.read_csv(\"data/all_benchmark_data.csv\")\n",
- "data[\"extra_benchmark_config\"] = data[\"extra_benchmark_config_str\"].apply(json.loads)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "88d54e77",
- "metadata": {
- "metadata": {}
- },
- "outputs": [],
- "source": [
- "kernel_name = \"kl_div\"\n",
- "metric_name = \"speed\"\n",
- "kernel_operation_mode = \"full\"\n",
- "\n",
- "filtered_data = data[\n",
- " (data[\"kernel_name\"] == kernel_name)\n",
- " & (data[\"metric_name\"] == metric_name)\n",
- " & (data[\"kernel_operation_mode\"] == kernel_operation_mode)\n",
- " # Use this to filter by extra benchmark configuration property\n",
- " # & (data['extra_benchmark_config'].apply(lambda x: x.get('H') == 4096))\n",
- "]\n",
- "\n",
- "print(filtered_data)\n",
- "if len(filtered_data) == 0:\n",
- " raise ValueError(\"No data found for the specified filter\")\n",
- "\n",
- "xlabel = filtered_data[\"x_label\"].iloc[0]\n",
- "ylabel = f\"{metric_name} ({filtered_data['metric_unit'].iloc[0]})\"\n",
- "# Sort by \"kernel_provider\" to ensure consistent color assignment\n",
- "filtered_data = filtered_data.sort_values(by=\"kernel_provider\")\n",
- "\n",
- "plt.figure(figsize=(10, 6))\n",
- "sns.set(style=\"whitegrid\")\n",
- "ax = sns.lineplot(\n",
- " data=filtered_data,\n",
- " x=\"x_value\",\n",
- " y=\"y_value_50\",\n",
- " hue=\"kernel_provider\",\n",
- " marker=\"o\",\n",
- " palette=\"tab10\",\n",
- " errorbar=(\"ci\", None),\n",
- ")\n",
- "\n",
- "# Seaborn can't plot pre-computed error bars, so we need to do it manually\n",
- "lines = ax.get_lines()\n",
- "colors = [line.get_color() for line in lines]\n",
- "\n",
- "for (kernel_provider, group_data), color in zip(\n",
- " filtered_data.groupby(\"kernel_provider\"), colors\n",
- "):\n",
- " # for i, row in group_data.iterrows():\n",
- " y_error_lower = group_data[\"y_value_50\"] - group_data[\"y_value_20\"]\n",
- " y_error_upper = group_data[\"y_value_80\"] - group_data[\"y_value_50\"]\n",
- " y_error = [y_error_lower, y_error_upper]\n",
- "\n",
- " plt.errorbar(\n",
- " group_data[\"x_value\"],\n",
- " group_data[\"y_value_50\"],\n",
- " yerr=y_error,\n",
- " fmt=\"o\",\n",
- " color=color,\n",
- " capsize=5,\n",
- " )\n",
- "plt.legend(title=\"Kernel Provider\")\n",
- "plt.xlabel(xlabel)\n",
- "plt.ylabel(ylabel)\n",
- "plt.tight_layout()\n",
- "\n",
- "print(\"Summary of filtered data found:\")\n",
- "print(filtered_data.describe(include=\"all\"))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "372c1ae1",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.12"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/benchmark/benchmarks_visualizer.py b/benchmark/benchmarks_visualizer.py
new file mode 100644
index 000000000..2cb9b1330
--- /dev/null
+++ b/benchmark/benchmarks_visualizer.py
@@ -0,0 +1,169 @@
+import json
+import os
+from argparse import ArgumentParser
+from dataclasses import dataclass
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+
+DATA_PATH = "data/all_benchmark_data.csv"
+VISUALIZATIONS_PATH = "visualizations/"
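+# Example invocation (run from the benchmark/ directory, since DATA_PATH and
+# VISUALIZATIONS_PATH above are relative):
+#   python benchmarks_visualizer.py --kernel-name kl_div --metric-name speed --kernel-operation-mode full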
+
+
+@dataclass
+class VisualizationsConfig:
+ """
+ Configuration for the visualizations script.
+
+ Args:
+        kernel_name (str): Kernel name to visualize (must match a `kernel_name` entry in the benchmark data CSV)
+ metric_name (str): Metric name to visualize (speed/memory)
+ kernel_operation_mode (str): Kernel operation mode to visualize (forward/backward/full). Defaults to "full"
+ display (bool): Display the visualization. Defaults to False
+        overwrite (bool): Overwrite an existing visualization; if none exists, this flag has no effect because one is always created and saved. Defaults to False
+
+ """
+
+ kernel_name: str
+ metric_name: str
+ kernel_operation_mode: str = "full"
+ display: bool = False
+ overwrite: bool = False
+
+
+def parse_args() -> VisualizationsConfig:
+ """Parse command line arguments into a configuration object.
+
+ Returns:
+ VisualizationsConfig: Configuration object for the visualizations script.
+ """
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--kernel-name", type=str, required=True, help="Kernel name to benchmark"
+ )
+ parser.add_argument(
+ "--metric-name",
+ type=str,
+ required=True,
+ help="Metric name to visualize (speed/memory)",
+ )
+ parser.add_argument(
+ "--kernel-operation-mode",
+ type=str,
+ required=True,
+ help="Kernel operation mode to visualize (forward/backward/full)",
+ )
+ parser.add_argument(
+ "--display", action="store_true", help="Display the visualization"
+ )
+ parser.add_argument(
+ "--overwrite",
+ action="store_true",
+ help="Overwrite existing visualization, if none exist this flag has no effect as one are always created",
+ )
+
+ args = parser.parse_args()
+
+    return VisualizationsConfig(**vars(args))
+
+
+def load_data(config: VisualizationsConfig) -> pd.DataFrame:
+ """Loads the benchmark data from the CSV file and filters it based on the configuration.
+
+ Args:
+ config (VisualizationsConfig): Configuration object for the visualizations script.
+
+ Raises:
+ ValueError: If no data is found for the given filters.
+
+ Returns:
+ pd.DataFrame: Filtered benchmark dataframe.
+ """
+ df = pd.read_csv(DATA_PATH)
+ df["extra_benchmark_config"] = df["extra_benchmark_config_str"].apply(json.loads)
+
+ filtered_df = df[
+ (df["kernel_name"] == config.kernel_name)
+ & (df["metric_name"] == config.metric_name)
+ & (df["kernel_operation_mode"] == config.kernel_operation_mode)
+ # Use this to filter by extra benchmark configuration property
+ # & (data['extra_benchmark_config'].apply(lambda x: x.get('H') == 4096))
+        # FIXME: maybe add a way to filter using some configuration, instead of hardcoding it
+ ]
+
+ if filtered_df.empty:
+ raise ValueError("No data found for the given filters")
+
+ return filtered_df
+
+
+def plot_data(df: pd.DataFrame, config: VisualizationsConfig):
+ """Plots the benchmark data, saving the result if needed.
+
+ Args:
+ df (pd.DataFrame): Filtered benchmark dataframe.
+ config (VisualizationsConfig): Configuration object for the visualizations script.
+ """
+ xlabel = df["x_label"].iloc[0]
+ ylabel = f"{config.metric_name} ({df['metric_unit'].iloc[0]})"
+ # Sort by "kernel_provider" to ensure consistent color assignment
+ df = df.sort_values(by="kernel_provider")
+
+ plt.figure(figsize=(10, 6))
+ sns.set(style="whitegrid")
+ ax = sns.lineplot(
+ data=df,
+ x="x_value",
+ y="y_value_50",
+ hue="kernel_provider",
+ marker="o",
+ palette="tab10",
+ errorbar=("ci", None),
+ )
+
+ # Seaborn can't plot pre-computed error bars, so we need to do it manually
+ lines = ax.get_lines()
+ colors = [line.get_color() for line in lines]
+
+ for (_, group_data), color in zip(df.groupby("kernel_provider"), colors):
+ y_error_lower = group_data["y_value_50"] - group_data["y_value_20"]
+ y_error_upper = group_data["y_value_80"] - group_data["y_value_50"]
+ y_error = [y_error_lower, y_error_upper]
+
+ plt.errorbar(
+ group_data["x_value"],
+ group_data["y_value_50"],
+ yerr=y_error,
+ fmt="o",
+ color=color,
+ capsize=5,
+ )
+ plt.legend(title="Kernel Provider")
+ plt.xlabel(xlabel)
+ plt.ylabel(ylabel)
+ plt.tight_layout()
+
+ out_path = os.path.join(
+ VISUALIZATIONS_PATH, f"{config.kernel_name}_{config.metric_name}.png"
+ )
+
+ if config.display:
+ plt.show()
+ if config.overwrite or not os.path.exists(
+ out_path
+ ): # Save the plot if it doesn't exist or if we want to overwrite it
+ os.makedirs(VISUALIZATIONS_PATH, exist_ok=True)
+ plt.savefig(out_path)
+ plt.close()
+
+
+def main():
+ config = parse_args()
+ df = load_data(config)
+ plot_data(df, config)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/huggingface/run_benchmarks.sh b/examples/huggingface/run_benchmarks.sh
new file mode 100755
index 000000000..f6df505bb
--- /dev/null
+++ b/examples/huggingface/run_benchmarks.sh
@@ -0,0 +1,50 @@
+## Benchmarking Script
+## Runs the training script with different configurations and logs the results
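+## Usage: ./run_benchmarks.sh (run from the examples/huggingface directory so that training.py and config/fsdp_config.json resolve)
+## Assumes a single node with 4 GPUs; adjust --nproc-per-node below if your setup differs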
+
+MODEL_TYPE="mistral"
+MODEL_PATH="mistralai/Mistral-7B-v0.1"
+USE_LIGER_VALUES=("True" "False")
+BATCH_SIZE_VALUES=(64 128 192)
+NUM_REP=5
+MAX_STEPS=20
+DATASET_PATH="tatsu-lab/alpaca"
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+mkdir -p "${SCRIPT_DIR}/results"
+
+for USE_LIGER in "${USE_LIGER_VALUES[@]}"; do
+ for BATCH_SIZE in "${BATCH_SIZE_VALUES[@]}"; do
+ echo "Running with use_liger=$USE_LIGER and batch_size=$BATCH_SIZE"
+
+ for ((i=1; i<=NUM_REP; i++)); do
+
+ LOG_FILE="${SCRIPT_DIR}/results/${MODEL_TYPE}_use_liger_${USE_LIGER}_batch_size_${BATCH_SIZE}_rep_${i}.log"
+
+ torchrun --nnodes=1 --nproc-per-node=4 training.py \
+ --bf16 \
+ --num_train_epochs 1 \
+ --max_steps $MAX_STEPS \
+ --model_name $MODEL_PATH \
+ --dataset $DATASET_PATH \
+ --per_device_train_batch_size $BATCH_SIZE \
+ --per_device_eval_batch_size 16 \
+ --eval_strategy "no" \
+ --save_strategy "no" \
+ --learning_rate 6e-6 \
+ --weight_decay 0.05 \
+ --warmup_ratio 0.1 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --include_num_input_tokens_seen \
+ --report_to none \
+ --fsdp "full_shard auto_wrap" \
+ --fsdp_config config/fsdp_config.json \
+ --seed 42 \
+ --use_liger $USE_LIGER \
+ --output_dir model_output_dir \
+ > $LOG_FILE
+
+ sleep 5
+ done
+ done
+done
\ No newline at end of file
diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index 32703788c..901809d4d 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -14,6 +14,7 @@ def liger_cross_entropy_kernel(
n_cols,
n_non_ignore,
ignore_index,
+ label_smoothing: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
@@ -30,6 +31,7 @@ def liger_cross_entropy_kernel(
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
ignore_index (int): The index to ignore in the target.
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
BLOCK_SIZE (int): The block size for Triton operations.
"""
@@ -63,12 +65,20 @@ def liger_cross_entropy_kernel(
X_ptr + y
) # we need to store the original value of X_y for the loss calculation
+ # Label smoothing is a general case of normal cross entropy
+ # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
+ scaled_x_sum = 0.0
+ eps = label_smoothing / n_cols
+
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
block_max = tl.max(X_block)
+ if label_smoothing > 0:
+ # scale X beforehand to avoid overflow
+ scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new
@@ -77,12 +87,16 @@ def liger_cross_entropy_kernel(
# dx_y = (softmax(x_y) - 1) / N
# dx_i = softmax(x_i) / N, i != y
# N is the number of non ignored elements in the batch
+ # For label smoothing:
+    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
+ # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
+ # = dx_i - (1 - label_smoothing) / N
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
- X_block = (tl.exp(X_block - m) / d) / (n_non_ignore)
+ X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
@@ -97,9 +111,21 @@ def liger_cross_entropy_kernel(
# So we can safely calculate log (softmax(X_y)) without overflow
loss = -(ori_X_y - m - tl.log(d))
- # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - 1) / N`
+    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+    #          = (1 - label_smoothing) * H(q, p) - eps * sum(logsoftmax(x_i))
+ # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
+ # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
+ # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
+ # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+ # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+ if label_smoothing > 0:
+ smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
+ loss = loss * (1 - label_smoothing) + smooth_loss
+
+    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N`
X_y = tl.load(X_ptr + y)
- X_y += -1 / (n_non_ignore)
+ X_y += -(1 - label_smoothing) / (n_non_ignore)
tl.store(loss_ptr, loss)
tl.store(X_ptr + y, X_y)
@@ -147,7 +173,7 @@ def element_mul_kernel(
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
-def cross_entropy_forward(_input, target, ignore_index):
+def cross_entropy_forward(_input, target, ignore_index, label_smoothing):
BT, V = _input.shape
n_rows = BT
@@ -175,6 +201,7 @@ def cross_entropy_forward(_input, target, ignore_index):
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
+ label_smoothing=label_smoothing,
BLOCK_SIZE=BLOCK_SIZE,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
@@ -216,7 +243,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
"""
@staticmethod
- def forward(ctx, _input, target, ignore_index):
+ def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0):
"""
The forward pass of the Liger Cross Entropy loss.
@@ -225,11 +252,14 @@ def forward(ctx, _input, target, ignore_index):
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
ignore_index (int): The index to ignore in the target.
+ label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
Returns:
tensor: The computed loss.
"""
- loss, _input = cross_entropy_forward(_input, target, ignore_index)
+ loss, _input = cross_entropy_forward(
+ _input, target, ignore_index, label_smoothing
+ )
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
@@ -254,4 +284,5 @@ def backward(ctx, grad_output):
_input,
None,
None,
+ None,
)
diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index 1164abee8..bf0e3da48 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -5,34 +5,16 @@
element_mul_kernel,
liger_cross_entropy_kernel,
)
-from liger_kernel.ops.utils import get_torch_activation
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 65536 // 2
-LOGIT_SOFTCAP_VAL = "softcap_value"
-LOGIT_SOFTCAP_ACT = "softcap_act"
def fused_linear_cross_entropy_forward(
- _input, weight, target, bias=None, final_logit_softcap_params=None, ignore_index=-100
+ _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
):
- if final_logit_softcap_params is not None:
- if {LOGIT_SOFTCAP_VAL, LOGIT_SOFTCAP_ACT} != set(
- final_logit_softcap_params.keys()
- ):
- raise Exception(
- f"final_logit_softcap_params should be a Dict with two keys {LOGIT_SOFTCAP_VAL}, {LOGIT_SOFTCAP_ACT}"
- )
- final_logit_softcap_params.update(
- {
- LOGIT_SOFTCAP_ACT: get_torch_activation(
- final_logit_softcap_params.get(LOGIT_SOFTCAP_ACT)
- )
- }
- )
-
dtype = (
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
)
@@ -72,16 +54,6 @@ def fused_linear_cross_entropy_forward(
logits_chunk = _input_chunk @ weight.t() # chunk_size x V
if bias is not None:
logits_chunk = logits_chunk + bias
- if final_logit_softcap_params is not None:
- logits_chunk = logits_chunk / final_logit_softcap_params.get(
- LOGIT_SOFTCAP_VAL
- )
- logits_chunk = final_logit_softcap_params.get(LOGIT_SOFTCAP_ACT)(
- logits_chunk
- )
- logits_chunk = logits_chunk * final_logit_softcap_params.get(
- LOGIT_SOFTCAP_VAL
- )
target_chunk = target[start_idx:end_idx] # chunk_size,
n_rows = logits_chunk.shape[0]
@@ -108,6 +80,7 @@ def fused_linear_cross_entropy_forward(
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
+ label_smoothing=label_smoothing,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,
)
@@ -200,13 +173,7 @@ def fused_linear_cross_entropy_backward(
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
@staticmethod
def forward(
- ctx,
- _input,
- weight,
- target,
- bias=None,
- final_logit_softcap_params=None,
- ignore_index=-100,
+ ctx, _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
):
"""
Fusing the last linear layer with cross-entropy loss
@@ -221,11 +188,10 @@ def forward(
target: (B*T) where each value is in [0, V-1]
weight: (V, H) where V is the number of classes
bias: (V) where V is the number of classes
- softcap_params: Dict with two keys: {"softcap_value": , "softcap_act": }
ignore_index: the index to ignore in the target
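+        label_smoothing: the amount of smoothing when computing the loss, where 0.0 means no smoothing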
"""
loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
- _input, weight, target, bias, final_logit_softcap_params, ignore_index
+ _input, weight, target, bias, ignore_index, label_smoothing
)
# downcast to dtype and store for backward
ctx.save_for_backward(
@@ -241,4 +207,4 @@ def backward(ctx, grad_output):
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
grad_output, grad_input, grad_weight, grad_bias
)
- return (grad_input, grad_weight, None, grad_bias, None, None)
+ return (grad_input, grad_weight, None, grad_bias, None)
diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py
index 1a10f5970..68fcf05d2 100644
--- a/src/liger_kernel/ops/rms_norm.py
+++ b/src/liger_kernel/ops/rms_norm.py
@@ -1,3 +1,15 @@
+"""
+This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+The following line
+https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/rms_norm.py#L30
+is based on code from Unsloth, located at:
+https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+
+Modifications made by Yanning Chen, 2024.
+"""
+
import operator
import torch
diff --git a/src/liger_kernel/ops/utils.py b/src/liger_kernel/ops/utils.py
index 4ce661a6f..0cc5fc3f8 100644
--- a/src/liger_kernel/ops/utils.py
+++ b/src/liger_kernel/ops/utils.py
@@ -1,3 +1,15 @@
+"""
+This file incorporates code from Unsloth licensed under the Apache License, Version 2.0.
+See the original Unsloth repository at https://github.com/unslothai/unsloth.
+
+The following line
+https://github.com/linkedin/Liger-Kernel/blob/7382a8761f9af679482b968f9348013d933947c7/src/liger_kernel/ops/utils.py#L23
+is based on code from Unsloth, located at:
+https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43
+
+Modifications made by Yanning Chen, 2024.
+"""
+
import functools
import importlib
from typing import Callable
diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py
index 44255eb85..0adb1cc87 100644
--- a/src/liger_kernel/transformers/cross_entropy.py
+++ b/src/liger_kernel/transformers/cross_entropy.py
@@ -6,6 +6,11 @@
class LigerCrossEntropyLoss(CrossEntropyLoss):
def __init__(self, *args, **kwargs):
super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
+ assert (self.label_smoothing >= 0) and (
+ self.label_smoothing <= 1
+ ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"
def forward(self, _input, target):
- return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index)
+ return LigerCrossEntropyFunction.apply(
+ _input, target, self.ignore_index, self.label_smoothing
+ )
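+
+
+# Illustrative usage (variable names and shapes are hypothetical):
+# LigerCrossEntropyLoss accepts the same `label_smoothing` argument as
+# torch.nn.CrossEntropyLoss, e.g.
+#   loss_fn = LigerCrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)
+#   loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1))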
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py
index c1fcccb34..cdc7000e5 100644
--- a/test/transformers/test_cross_entropy.py
+++ b/test/transformers/test_cross_entropy.py
@@ -61,6 +61,61 @@ def _test_correctness_with_ignore_index_once(
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
+def _test_correctness_with_label_smoothing_once(
+ target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
+):
+ torch.manual_seed(0)
+ torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing)
+
+ _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
+ _input = _tensor.detach().clone().requires_grad_(True)
+ _input2 = _tensor.detach().clone().requires_grad_(True)
+
+ target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)
+
+ output = torch_ce(_input, target)
+ output2 = target_ce(_input2, target)
+
+ assert torch.allclose(output, output2, atol=atol, rtol=rtol)
+
+ output.backward()
+ output2.backward()
+ assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
+
+
+def _test_correctness_with_label_smoothing_with_ignore_index_once(
+ target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
+):
+ torch.manual_seed(0)
+ torch_ce = CrossEntropyLoss(
+ ignore_index=ignore_index, label_smoothing=label_smoothing
+ )
+
+ _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
+ _input = _tensor.detach().clone().requires_grad_(True)
+ _input2 = _tensor.detach().clone().requires_grad_(True)
+
+ target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)
+
+ # Assign some random number of elements as ignore_index
+ num_elements_to_assign = torch.randint(
+ 1, B * T // 2, (1,)
+ ).item() # Random number of elements to set to ignore_index
+ indices_to_assign = torch.randperm(B * T)[
+ :num_elements_to_assign
+ ] # Randomly select indices
+ target[indices_to_assign] = ignore_index
+
+ output = torch_ce(_input, target)
+ output2 = target_ce(_input2, target)
+
+ assert torch.allclose(output, output2, atol=atol, rtol=rtol)
+
+ output.backward()
+ output2.backward()
+ assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
+
+
def _test_correctness_not_last_layer_once(
target_ce, B, T, V, scalar, dtype, atol, rtol
):
@@ -248,6 +303,125 @@ def test_correctness_with_ignore_index(
)
+@pytest.mark.parametrize(
+ "B, T, V, label_smoothing",
+ [
+ (2, 4096, 32000, 0.1), # llama2, mistral
+ (1, 4096, 128256, 0.1), # llama3
+ # weird shapes
+ (3, 423, 32000, 0.1),
+ ],
+)
+@pytest.mark.parametrize(
+ "scalar, dtype, atol, rtol",
+ [
+ pytest.param(
+ 0.1,
+ torch.bfloat16,
+ 1e-8,
+ 5e-2,
+ marks=pytest.mark.skipif(
+ not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+ ),
+ ),
+ pytest.param(
+ 1.0,
+ torch.bfloat16,
+ 1e-8,
+ 5e-2,
+ marks=pytest.mark.skipif(
+ not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+ ),
+ ),
+ pytest.param(
+ 10.0,
+ torch.bfloat16,
+ 1e-8,
+ 5e-2,
+ marks=pytest.mark.skipif(
+ not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+ ),
+ ),
+ (0.1, torch.float32, 1e-8, 1e-6),
+ (1.0, torch.float32, 1e-8, 1e-6),
+ (10.0, torch.float32, 1e-8, 1e-6),
+ ],
+)
+@pytest.mark.skipif(
+ torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
+ reason="Needs 16GB+ GPU memory.",
+)
+def test_correctness_with_label_smoothing_once(
+ B, T, V, label_smoothing, scalar, dtype, atol, rtol
+):
+ liger_ce = LigerCrossEntropyLoss(label_smoothing=label_smoothing)
+ _test_correctness_with_label_smoothing_once(
+ liger_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
+ )
+
+
+@pytest.mark.parametrize(
+ "B, T, V, ignore_index, label_smoothing",
+ [
+ (2, 4096, 32000, 1, 0.1), # llama2, mistral
+ (2, 4096, 32000, -100, 0.2), # llama2, mistral
+ (1, 4096, 128256, 2, 0.1), # llama3
+ # weird shapes
+ (3, 423, 32000, -300, 0.2),
+ ],
+)
+@pytest.mark.parametrize(
+ "scalar, dtype, atol, rtol",
+ [
+ pytest.param(
+ 0.1,
+ torch.bfloat16,
+ 1e-8,
+ 5e-2,
+ marks=pytest.mark.skipif(
+ not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+ ),
+ ),
+ pytest.param(
+ 1.0,
+ torch.bfloat16,
+ 1e-8,
+ 5e-2,
+ marks=pytest.mark.skipif(
+ not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+ ),
+ ),
+ pytest.param(
+ 10.0,
+ torch.bfloat16,
+ 1e-8,
+ 5e-2,
+ marks=pytest.mark.skipif(
+ not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+ ),
+ ),
+ (0.1, torch.float32, 1e-8, 1e-6),
+ (1.0, torch.float32, 1e-8, 1e-6),
+ (10.0, torch.float32, 1e-8, 1e-6),
+ ],
+)
+@pytest.mark.skipif(
+ torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
+ reason="Needs 16GB+ GPU memory.",
+)
+def test_correctness_with_label_smoothing_with_ignore_index_once(
+ B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
+):
+ liger_ce = LigerCrossEntropyLoss(
+ ignore_index=ignore_index,
+ label_smoothing=label_smoothing,
+ )
+ _test_correctness_with_label_smoothing_with_ignore_index_once(
+ liger_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
+ )
+
+
@pytest.mark.parametrize(
"B, T, V",
[