Add KTO Loss #475

Open · wants to merge 10 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ site/
.venv/
venv/
.ipynb_checkpoints/
.vscode/

# Misc
.DS_Store
30 changes: 30 additions & 0 deletions benchmark/README.md
@@ -0,0 +1,30 @@
## Benchmarking Liger Kernels

Follow these steps to benchmark and visualize kernel performance:

1. Create a benchmark script
   - Add your script under `benchmark/scripts/`
   - Name it after the kernel it measures (e.g., `benchmark_<kernel_name>.py`); a minimal skeleton is sketched after these steps

2. Run the benchmark
   - Results will be saved to `benchmark/data/all_benchmark_data.csv`

   Example: Benchmarking KTO Loss
   ```bash
   cd benchmark
   python scripts/benchmark_kto_loss.py
   ```

3. Visualize the results
   - Use the visualization script with the appropriate parameters

   Example: Visualizing KTO Loss benchmark results
   ```bash
   python benchmarks_visualizer.py \
       --kernel-name kto_loss \
       --metric-name memory \
       --kernel-operation-mode full
   ```

4. View the results
   - Generated plots will be saved in `benchmark/visualizations/`
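
For reference, a new benchmark script generally defines one function per metric and hands it to `run_benchmarks`. The trimmed skeleton below is modeled on `benchmark/scripts/benchmark_kto_loss.py` from this PR; the helpers (`QUANTILES`, `SingleBenchmarkRunInput`, `SingleBenchmarkRunOutput`, `parse_benchmark_script_args`, `run_benchmarks`) come from `benchmark/scripts/utils.py`, while `my_kernel` and the placeholder workload are purely illustrative:

```python
import torch
import triton
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    parse_benchmark_script_args,
    run_benchmarks,
)


def bench_speed_my_kernel(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    # The swept value (here: batch size) arrives in input.x; fixed settings
    # arrive in input.extra_benchmark_config, and input.kernel_provider selects
    # which implementation ("liger" vs. "huggingface") to time.
    B = input.x
    T = input.extra_benchmark_config["T"]

    def fwd():
        # Placeholder workload; a real script calls the kernel under test here.
        return torch.randn(B, T).sum()

    ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


if __name__ == "__main__":
    args = parse_benchmark_script_args()
    run_benchmarks(
        bench_test_fn=bench_speed_my_kernel,
        kernel_name="my_kernel",  # hypothetical kernel name
        kernel_operation_modes=["forward"],
        metric_name="speed",
        metric_unit="ms",
        x_name="B",
        x_label="Batch Size (B)",
        x_values=[2, 4, 8],
        kernel_providers=["liger", "huggingface"],
        extra_benchmark_configs=[{"T": 512}],
        overwrite=args.overwrite,
    )
```

Real scripts, such as the KTO one added in this PR, benchmark the Liger implementation against a Hugging Face reference for both speed and memory.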
30 changes: 30 additions & 0 deletions benchmark/data/all_benchmark_data.csv
@@ -715,3 +715,33 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,8.2532958984375,8.235372543334961,8.274937629699707,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,16.888959884643555,16.879615783691406,16.898893356323242,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,32.13854217529297,32.12795639038086,32.149131774902344,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,64.81161499023438,64.81161499023438,64.81161499023438,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,128.68646240234375,128.68646240234375,128.68646240234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,7.146656036376953,7.143622398376465,7.152345657348633,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,12.538240432739258,12.521356582641602,12.540371894836426,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,26.29542350769043,25.303590774536133,26.88591957092285,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,49.26508712768555,49.26508712768555,49.26508712768555,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,98.9525146484375,98.9525146484375,98.9525146484375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),2,9.005151748657227,8.97766399383545,9.046483039855957,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),4,19.108863830566406,19.09713363647461,19.185260772705078,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),8,32.80137634277344,32.775360107421875,32.827388763427734,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),16,65.46678161621094,65.46678161621094,65.46678161621094,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),32,129.91734313964844,129.91734313964844,129.91734313964844,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,16.091487884521484,14.86076831817627,16.23084831237793,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,28.04204750061035,28.03957176208496,28.055641174316406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,54.70073699951172,54.70073699951172,54.70073699951172,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,108.09929656982422,108.09929656982422,108.09929656982422,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,215.1945343017578,215.1945343017578,215.1945343017578,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),2,3037.75390625,3037.75390625,3037.75390625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),4,3800.0126953125,3800.0126953125,3800.0126953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),8,4565.28076171875,4565.28076171875,4565.28076171875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),16,4589.31787109375,4589.31787109375,4589.31787109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),32,4637.39208984375,4637.39208984375,4637.39208984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,4793.7626953125,4793.7626953125,4793.7626953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6551.2978515625,6551.2978515625,6551.2978515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,10063.3681640625,10063.3681640625,10063.3681640625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,17093.5078125,17093.5078125,17093.5078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,31153.7890625,31153.7890625,31153.7890625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
264 changes: 264 additions & 0 deletions benchmark/scripts/benchmark_kto_loss.py
@@ -0,0 +1,264 @@
import os
import sys

import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
from liger_kernel.utils import infer_device

device = infer_device()
# Make the repository root importable so the HF reference implementation in
# test/chunked_loss/test_kto_loss.py can be imported inside TorchKTOLoss below.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchKTOLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
bias: bool = False,
ref_bias: bool = False,
ignore_index: int = -100,
beta: float = 0.1,
):
from test.chunked_loss.test_kto_loss import HFKTOLoss

super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.ref_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=ref_bias, dtype=dtype
)
self.kto_loss = HFKTOLoss(
ignore_index=ignore_index, beta=beta, use_ref_model=True
).get_batch_loss_metrics

def forward(self, x, ref_x, y):
return self.kto_loss(
self.lin.weight,
x,
y,
self.lin.bias,
ref_x,
self.ref_lin.weight,
self.ref_lin.bias,
)[0]


class LigerKTOLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
bias: bool = False,
ref_bias: bool = False,
ignore_index: int = -100,
beta: float = 0.1,
):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.ref_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=ref_bias, dtype=dtype
)
self.kto_loss = LigerFusedLinearKTOLoss(
ignore_index=ignore_index, beta=beta, use_ref_model=True
)

def forward(self, x, ref_x, y):
return self.kto_loss(
self.lin.weight,
x,
y,
self.lin.bias,
ref_x,
self.ref_lin.weight,
self.ref_lin.bias,
)[0]


def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider

torch_kto_loss = TorchKTOLoss(
H=H,
V=V,
dtype=dtype,
bias=bias,
ref_bias=bias,
ignore_index=ignore_index,
beta=beta,
).to(device)

liger_kto_loss = LigerKTOLoss(
H=H,
V=V,
dtype=dtype,
bias=bias,
ref_bias=bias,
ignore_index=ignore_index,
beta=beta,
).to(device)

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
# Target shape: [B, T]
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

# Add ignore_index tokens to simulate padding
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

# Add ref_x with the same shape as _input
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

def fwd():
if provider == "liger":
return liger_kto_loss(_input, ref_input, target)
elif provider == "huggingface":
return torch_kto_loss(_input, ref_input, target)

def full():
y = fwd()
y.backward()

mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_kto_loss = TorchKTOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
liger_kto_loss = LigerKTOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)

# Target shape: [B, T]
target = torch.randint(V, (B, T), device=device, dtype=torch.long)

# Add ignore_index tokens
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

# Add ref_x with the same shape as _input
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

def fwd():
if provider == "liger":
return liger_kto_loss(_input, ref_input, target)
elif provider == "huggingface":
return torch_kto_loss(_input, ref_input, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":

def full():
y = fwd()
y.backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

common_configs = {
"kernel_name": "kto_loss",
"x_name": "B",
"x_label": "Batch Size (B)",
"x_values": [2**i for i in range(1, 6)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 512,
"H": 1024,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
"bias": True,
"beta": 0.1,
"ignore_index": 42,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_kto_loss,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
)

run_benchmarks(
bench_test_fn=bench_memory_kto_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
)
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/README.md
@@ -1,6 +1,6 @@
# Liger FlexChunkLoss: Alignment and Distillation loss

Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
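
A minimal usage sketch of the new `LigerFusedLinearKTOLoss` (the calling convention below mirrors `benchmark/scripts/benchmark_kto_loss.py` from this PR; the sizes, random tensors, and standalone `lm_head` layers are illustrative only):

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
from liger_kernel.utils import infer_device

device = infer_device()
B, T, H, V = 2, 128, 1024, 32000  # illustrative sizes
dtype = torch.bfloat16

# Policy and (frozen) reference lm_head projections.
lm_head = torch.nn.Linear(H, V, bias=True, dtype=dtype).to(device)
ref_lm_head = torch.nn.Linear(H, V, bias=True, dtype=dtype).to(device)

kto_loss = LigerFusedLinearKTOLoss(ignore_index=-100, beta=0.1, use_ref_model=True)

x = torch.randn(B, T, H, dtype=dtype, device=device)      # policy hidden states
ref_x = torch.randn(B, T, H, dtype=dtype, device=device)  # reference hidden states
target = torch.randint(0, V, (B, T), device=device)       # labels; -100 marks ignored tokens

# The loss object returns a tuple; the first element is the scalar loss.
loss = kto_loss(
    lm_head.weight, x, target, lm_head.bias,
    ref_x, ref_lm_head.weight, ref_lm_head.bias,
)[0]
loss.backward()  # gradients accumulate in lm_head.weight / lm_head.bias
```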

### User interface

1 change: 1 addition & 0 deletions src/liger_kernel/chunked_loss/__init__.py
@@ -1,4 +1,5 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401