Add Chunked SimPO Loss (#386)
## Summary
This PR adds the Simple Preference Optimization Loss function. The only
difference between SimPO and CPO is a margin term `gamma` which
specifies that the preferred response should be atleast gamma logit
points better than the losing response.

$$\mathrm{SimPO\ Loss} = -\log\sigma\big(\beta\log\pi_\theta(y_c \mid x) - \beta\log\pi_\theta(y_r \mid x) - \gamma\big)$$

Note that, unlike DPO, SimPO explicitly requires $$\log\pi_\theta(y \mid x)$$ to be
normalized by the response length (i.e., the average per-token log-probability is used).

This corresponds to Eq 6 in the
[paper](https://arxiv.org/pdf/2405.14734).
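
For reference, here is a minimal standalone sketch of Eq. 6 operating directly on
length-normalized log-probabilities. This is illustrative only and is not the chunked,
fused kernel added in this PR; the function and argument names are made up, and the
defaults simply mirror the ones used elsewhere in this PR.

```python
import torch
import torch.nn.functional as F


def simpo_loss(
    chosen_logps: torch.Tensor,    # average (length-normalized) log prob of y_c, shape (B,)
    rejected_logps: torch.Tensor,  # average (length-normalized) log prob of y_r, shape (B,)
    beta: float = 0.1,
    gamma: float = 0.5,
) -> torch.Tensor:
    # Eq. 6: -log sigmoid(beta * logp(y_c) - beta * logp(y_r) - gamma)
    margin = beta * (chosen_logps - rejected_logps) - gamma
    return -F.logsigmoid(margin).mean()
```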

## Testing Done
GPU A100-80G-SXM

![Screenshot 2024-11-15 at 2 38 23 PM](https://github.com/user-attachments/assets/ac126f94-ebd8-4457-a4a2-53832699af4c)
![Screenshot 2024-11-15 at 2 38 37 PM](https://github.com/user-attachments/assets/e539e9cd-f66a-42dd-8b43-3ae44dcd42a0)


- Hardware Type: GPU A100-80G-SXM
- [X] run `make test` to ensure correctness
- [X] run `make checkstyle` to ensure code style
- [X] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
pramodith and ByronHsu authored Nov 19, 2024
1 parent 11ec97b commit ebd5303
Showing 8 changed files with 381 additions and 8 deletions.
24 changes: 24 additions & 0 deletions benchmark/data/all_benchmark_data.csv
@@ -691,3 +691,27 @@
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,2,30.28438377380371,30.107013702392578,30.284786224365234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,4,58.80876922607422,58.80876922607422,58.80876922607422,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,8,117.96163177490234,117.96163177490234,117.96163177490234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,16,235.60794067382812,235.60794067382812,235.60794067382812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,2,14.513839721679688,14.510687828063965,14.517855644226074,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,4,28.78099250793457,28.72719383239746,28.792186737060547,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,8,52.5733757019043,52.5733757019043,52.5733757019043,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,16,104.44764709472656,104.44764709472656,104.44764709472656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,2,31.566062927246094,31.457612991333008,31.674514770507812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,4,61.4403190612793,61.4403190612793,61.4403190612793,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,8,119.97705841064453,119.97705841064453,119.97705841064453,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,16,238.13417053222656,238.13417053222656,238.13417053222656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,2,39.811119079589844,39.65474319458008,39.96749496459961,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,4,77.20928192138672,77.20928192138672,77.20928192138672,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,8,153.6952667236328,153.6952667236328,153.6952667236328,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,16,307.7382507324219,307.7382507324219,307.7382507324219,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,2,7675.3291015625,7675.3291015625,7675.3291015625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,4,7723.3447265625,7723.3447265625,7723.3447265625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,8,7819.3759765625,7819.3759765625,7819.3759765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,16,8011.4384765625,8011.4384765625,8011.4384765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
191 changes: 191 additions & 0 deletions benchmark/scripts/benchmark_simpo_loss.py
@@ -0,0 +1,191 @@
import os
import sys

import torch
import triton
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    _test_memory,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchLMHeadSimPO(torch.nn.Module):
    """Ground truth implementation of the final linear layer fused with a torch-based SimPO loss.
    :param H: hidden size
    :param V: vocab size
    :param ignore_index: index to ignore
    :param reduction: reduction method
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        from test.chunked_loss.test_cpo_loss import HFCPOLoss

        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.simpo_loss = HFCPOLoss(loss_type="simpo").get_batch_loss_metrics

    def forward(self, x, y):
        return self.simpo_loss(x, self.lin.weight, y)


class LigerLMHeadSimPO(torch.nn.Module):
    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        self.simpo_loss = LigerFusedLinearSimPOFunction.apply

    def forward(self, x, y):
        return self.simpo_loss(x, self.lin.weight, y)


#############################################################################
# Test the memory consumption of the fused linear SimPO loss
#############################################################################


def bench_memory_fused_linear_simpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider

    device = "cuda"
    torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    def fwd():
        if provider == "liger":
            return liger_lm_head_simpo(_input, target)
        elif provider == "huggingface":
            return torch_lm_head_simpo(_input, target)

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


#############################################################################
# Test the speed of the fused linear SimPO loss
#############################################################################


def bench_speed_fused_linear_simpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    device = "cuda"

    torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    def fwd():
        if provider == "liger":
            return liger_lm_head_simpo(_input, target)
        elif provider == "huggingface":
            return torch_lm_head_simpo(_input, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "fused_linear_simpo_loss",
        "x_name": "B",
        "x_label": "B",
        "x_values": [2**i for i in range(1, 5)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "T": 1024,
                "H": 4096,
                "V": 128256,
                "mode": "forward",
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_fused_linear_simpo_loss,
        kernel_operation_modes=["forward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs
    )
    run_benchmarks(
        bench_test_fn=bench_memory_fused_linear_simpo_loss,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs
    )
8 changes: 6 additions & 2 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -32,6 +32,7 @@ def forward(
        alpha=1.0,
        beta=0.1,
        compiled=True,
        **loss_kwargs,
    ):
        """
        Base class for fused linear layer with preference loss.
@@ -49,6 +50,7 @@
            alpha (float): Weight for the NLL loss.
            beta (float): Weight for the odds ratio loss.
            compiled (bool): Whether to use torch compile for chunk accumulation.
            loss_kwargs (dict): Other possible arguments that a loss function might need
        """
        # TODO: Tune CHUNK_SIZE to fully utilize the GPU
        CHUNK_SIZE = chunk_size
@@ -68,6 +70,7 @@
            beta=beta,
            compute_nll_loss=compute_nll_loss,
            full_target=target,
            **loss_kwargs,
        )

        def accumulate_chunk(input_chunk, target_chunk):
@@ -94,6 +97,9 @@ def accumulate_chunk(input_chunk, target_chunk):
            loss_acc.add_(chunk_loss)
            return chunk_grad_input

        if compiled:
            accumulate_chunk = torch.compile(accumulate_chunk)

        len_chosen = target.shape[0] // 2
        _chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
        _chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
@@ -116,8 +122,6 @@ def accumulate_chunk(input_chunk, target_chunk):
                [chosen_target_chunk, rejected_target_chunk], dim=0
            )

            if compiled:
                accumulate_chunk = torch.compile(accumulate_chunk)
            grad_input = accumulate_chunk(input_chunk, target_chunk)

            grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/orpo_loss.py
@@ -34,7 +34,7 @@ def forward(
        ignore_index=-100,
        beta=0.1,
        compute_nll_loss=True,
        compiled=True,
        compiled=False,
    ):
        """
        Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
64 changes: 64 additions & 0 deletions src/liger_kernel/chunked_loss/simpo_loss.py
@@ -0,0 +1,64 @@
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
    LigerFusedLinearPreferenceBase,
)


class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1, gamma=0.5):
        """
        Compute SimPO loss.
        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Scaling factor for the log-probability margin.
            gamma (float): The SimPO target reward margin.
        """
        logits = beta * (chosen_logps - rejected_logps) - gamma
        loss = F.logsigmoid(logits).mean()
        return loss

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        compute_nll_loss=False,
        compiled=True,
        gamma=0.5,
    ):
        """
        Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734
        Handles both the forward and backward pass of the final linear layer with SimPO loss.
        Inspired by LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989), which fuses the final linear layer and CE loss.
        """

        return LigerFusedLinearPreferenceBase.forward(
            ctx,
            _input,
            weight,
            target,
            bias,
            loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
            compute_nll_loss=compute_nll_loss,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compiled=compiled,
            gamma=gamma,
        )

    @staticmethod
    def backward(ctx, grad_output):
        # Get gradients for _input, weight, bias, and target from the base class
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # Return these gradients, followed by None for the remaining inputs
        return *grads, None, None, None, None, None, None
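
The HF-style reference in `test/chunked_loss/test_cpo_loss.py` (diff below) expresses the
same margin by shifting the log-probability difference by `simpo_gamma / beta` before
scaling by `beta`. A quick standalone sanity check of the equivalence, using made-up
illustrative values only:

```python
import torch
import torch.nn.functional as F

# Hypothetical length-normalized log-probs for two preference pairs.
chosen_logps = torch.tensor([-0.8, -1.1])
rejected_logps = torch.tensor([-1.5, -1.3])
beta, gamma = 0.1, 0.5

# Formulation in simpo_loss.py: scale first, then subtract the margin.
liger_style = F.logsigmoid(beta * (chosen_logps - rejected_logps) - gamma)

# Formulation in the HF-style test: shift by gamma / beta, then scale.
hf_style = F.logsigmoid(beta * ((chosen_logps - rejected_logps) - gamma / beta))

assert torch.allclose(liger_style, hf_style)
```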
12 changes: 10 additions & 2 deletions test/chunked_loss/test_cpo_loss.py
@@ -22,11 +22,14 @@ def __init__(
        beta: float = 0.1,
        ignore_index: int = -100,
        label_smoothing: float = 0.0,
        simpo_gamma: float = 0.5,
        loss_type: str = "sigmoid",
    ):
        super().__init__(alpha=alpha, beta=beta, ignore_index=ignore_index)
        # Sigmoid defaults to the CPO loss defined in the paper listed above.
        self.loss_type = "sigmoid"
        self.loss_type = loss_type
        self.label_smoothing = label_smoothing
        self.simpo_gamma = simpo_gamma

    def alignment_loss(
        self,
@@ -55,6 +58,12 @@ def alignment_loss(
                F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                + F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        elif self.loss_type == "simpo":
            logits = logits - (self.simpo_gamma / self.beta)
            losses = (
                F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                + F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        else:
            raise ValueError(
                f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'simpo']"
@@ -66,7 +75,6 @@
@pytest.mark.parametrize(
    "B, T, H, V",
    [
        # (1, 2, 12, 128),
        (8, 128, 1024, 4096),
        (3, 47, 31, 123),  # random shape
    ],
