-
Notifications
You must be signed in to change notification settings - Fork 230
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## Summary This PR adds the Simple Preference Optimization Loss function. The only difference between SimPO and CPO is a margin term `gamma` which specifies that the preferred response should be atleast gamma logit points better than the losing response. $$SimPOLoss = -\log(\sigma(\beta\log(\pi_\theta(y_c|x)) - \beta\log(\pi_\theta(y_r|x)) - \gamma))$$ Note that SimPO explicitly specifies that $$\pi_\theta(y|x)$$ needs to be normalized by length, unlike DPO. This corresponds to Eq 6 in the [paper](https://arxiv.org/pdf/2405.14734). ## Testing Done GPU A100-80G-SXM ![Screenshot 2024-11-15 at 2 38 23 PM](https://github.com/user-attachments/assets/ac126f94-ebd8-4457-a4a2-53832699af4c) ![Screenshot 2024-11-15 at 2 38 37 PM](https://github.com/user-attachments/assets/e539e9cd-f66a-42dd-8b43-3ae44dcd42a0) - Hardware Type: <BLANK> - [X] run `make test` to ensure correctness - [X] run `make checkstyle` to ensure code style - [X] run `make test-convergence` to ensure convergence --------- Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
- Loading branch information
Showing
8 changed files
with
381 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
import triton | ||
from utils import ( | ||
QUANTILES, | ||
SingleBenchmarkRunInput, | ||
SingleBenchmarkRunOutput, | ||
_test_memory, | ||
parse_benchmark_script_args, | ||
run_benchmarks, | ||
) | ||
|
||
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction | ||
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) | ||
|
||
|
||
class TorchLMHeadSimPO(torch.nn.Module): | ||
"""Ground truth implementation of the linear fused with torch based cross entropy loss. | ||
:param H: hidden size | ||
:param V: vocab size | ||
:param ignore_index: index to ignore | ||
:param reduction: reduction method | ||
""" | ||
|
||
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | ||
from test.chunked_loss.test_cpo_loss import HFCPOLoss | ||
|
||
super().__init__() | ||
self.lin = torch.nn.Linear( | ||
in_features=H, out_features=V, bias=False, dtype=dtype | ||
) | ||
self.simpo_loss = HFCPOLoss(loss_type="simpo").get_batch_loss_metrics | ||
|
||
def forward(self, x, y): | ||
return self.simpo_loss(x, self.lin.weight, y) | ||
|
||
|
||
class LigerLMHeadSimPO(torch.nn.Module): | ||
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100): | ||
super().__init__() | ||
self.lin = torch.nn.Linear( | ||
in_features=H, out_features=V, bias=False, dtype=dtype | ||
) | ||
self.simpo_loss = LigerFusedLinearSimPOFunction.apply | ||
|
||
def forward(self, x, y): | ||
return self.simpo_loss(x, self.lin.weight, y) | ||
|
||
|
||
############################################################################# | ||
# Test the memory consumption of the linear fused cross entropy loss | ||
############################################################################# | ||
|
||
|
||
def bench_memory_fused_linear_simpo_loss( | ||
input: SingleBenchmarkRunInput, | ||
) -> SingleBenchmarkRunOutput: | ||
B = input.x | ||
T = input.extra_benchmark_config["T"] | ||
H = input.extra_benchmark_config["H"] | ||
V = input.extra_benchmark_config["V"] | ||
dtype = input.extra_benchmark_config["dtype"] | ||
provider = input.kernel_provider | ||
|
||
device = "cuda" | ||
torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) | ||
liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) | ||
|
||
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) | ||
target = torch.randint(V, (B, T), dtype=torch.long, device=device) | ||
|
||
def fwd(): | ||
if provider == "liger": | ||
return liger_lm_head_simpo(_input, target) | ||
elif provider == "huggingface": | ||
return torch_lm_head_simpo(_input, target) | ||
|
||
def full(): | ||
y = fwd() | ||
y.backward() | ||
|
||
mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) | ||
return SingleBenchmarkRunOutput( | ||
y_20=mem_20, | ||
y_50=mem_50, | ||
y_80=mem_80, | ||
) | ||
|
||
|
||
# ############################################################################# | ||
# # Test the speed of the fused linear cross entropy loss | ||
# ############################################################################# | ||
|
||
|
||
def bench_speed_fused_linear_simpo_loss( | ||
input: SingleBenchmarkRunInput, | ||
) -> SingleBenchmarkRunOutput: | ||
B = input.x | ||
T = input.extra_benchmark_config["T"] | ||
H = input.extra_benchmark_config["H"] | ||
V = input.extra_benchmark_config["V"] | ||
dtype = input.extra_benchmark_config["dtype"] | ||
provider = input.kernel_provider | ||
mode = input.kernel_operation_mode | ||
|
||
device = "cuda" | ||
|
||
torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) | ||
liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device) | ||
|
||
_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) | ||
target = torch.randint(V, (B, T), dtype=torch.long, device=device) | ||
|
||
def fwd(): | ||
if provider == "liger": | ||
return liger_lm_head_simpo(_input, target) | ||
elif provider == "huggingface": | ||
return torch_lm_head_simpo(_input, target) | ||
|
||
if mode == "forward": | ||
ms_50, ms_20, ms_80 = triton.testing.do_bench( | ||
fwd, | ||
rep=100, | ||
quantiles=QUANTILES, | ||
) | ||
elif mode == "backward": | ||
y = fwd() | ||
|
||
ms_50, ms_20, ms_80 = triton.testing.do_bench( | ||
lambda: y.backward(retain_graph=True), | ||
grad_to_none=[_input], | ||
rep=100, | ||
quantiles=QUANTILES, | ||
) | ||
elif mode == "full": | ||
|
||
def full(): | ||
y = fwd() | ||
y.backward() | ||
|
||
ms_50, ms_20, ms_80 = triton.testing.do_bench( | ||
full, | ||
rep=100, | ||
quantiles=QUANTILES, | ||
) | ||
return SingleBenchmarkRunOutput( | ||
y_20=ms_20, | ||
y_50=ms_50, | ||
y_80=ms_80, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_benchmark_script_args() | ||
|
||
common_configs = { | ||
"kernel_name": "fused_linear_simpo_loss", | ||
"x_name": "B", | ||
"x_label": "B", | ||
"x_values": [2**i for i in range(1, 5)], | ||
"kernel_providers": ["liger", "huggingface"], | ||
"extra_benchmark_configs": [ | ||
{ | ||
"T": 1024, | ||
"H": 4096, | ||
"V": 128256, | ||
"mode": "forward", | ||
"dtype": torch.bfloat16, | ||
} | ||
], | ||
"overwrite": args.overwrite, | ||
} | ||
|
||
run_benchmarks( | ||
bench_test_fn=bench_speed_fused_linear_simpo_loss, | ||
kernel_operation_modes=["forward", "full"], | ||
metric_name="speed", | ||
metric_unit="ms", | ||
**common_configs | ||
) | ||
run_benchmarks( | ||
bench_test_fn=bench_memory_fused_linear_simpo_loss, | ||
kernel_operation_modes=["full"], | ||
metric_name="memory", | ||
metric_unit="MB", | ||
**common_configs | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import torch.nn.functional as F | ||
|
||
from liger_kernel.chunked_loss.fused_linear_preference import ( | ||
LigerFusedLinearPreferenceBase, | ||
) | ||
|
||
|
||
class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase): | ||
|
||
@staticmethod | ||
def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1, gamma=0.5): | ||
""" | ||
Compute odds-ratio loss. | ||
Args: | ||
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). | ||
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). | ||
beta (float): Weight for the odds ratio loss. | ||
gamma (float): The simpo gamma, margin term. | ||
""" | ||
logits = beta * (chosen_logps - rejected_logps) - gamma | ||
loss = F.logsigmoid(logits).mean() | ||
return loss | ||
|
||
@staticmethod | ||
def forward( | ||
ctx, | ||
_input, | ||
weight, | ||
target, | ||
bias=None, | ||
ignore_index=-100, | ||
beta=0.1, | ||
alpha=1.0, | ||
compute_nll_loss=False, | ||
compiled=True, | ||
gamma=0.5, | ||
): | ||
""" | ||
Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734 | ||
Handles both the forward and backward pass of the final linear layer with SimPO loss. | ||
Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss. | ||
""" | ||
|
||
return LigerFusedLinearPreferenceBase.forward( | ||
ctx, | ||
_input, | ||
weight, | ||
target, | ||
bias, | ||
loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn, | ||
compute_nll_loss=compute_nll_loss, | ||
ignore_index=ignore_index, | ||
alpha=alpha, | ||
beta=beta, | ||
compiled=compiled, | ||
gamma=gamma, | ||
) | ||
|
||
@staticmethod | ||
def backward(ctx, grad_output): | ||
# Get gradients for _input, weight, bias, and target from the base class | ||
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] | ||
# Return these gradients, followed by None for the remaining inputs | ||
return *grads, None, None, None, None, None, None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.