Add Chunked SimPO Loss #386

Merged · 9 commits · Nov 19, 2024
24 changes: 24 additions & 0 deletions benchmark/data/all_benchmark_data.csv
@@ -691,3 +691,27 @@
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_cpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-14 16:57:37,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,2,30.28438377380371,30.107013702392578,30.284786224365234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,4,58.80876922607422,58.80876922607422,58.80876922607422,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,8,117.96163177490234,117.96163177490234,117.96163177490234,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,liger,forward,speed,ms,B,B,16,235.60794067382812,235.60794067382812,235.60794067382812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:26,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,2,14.513839721679688,14.510687828063965,14.517855644226074,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,4,28.78099250793457,28.72719383239746,28.792186737060547,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,8,52.5733757019043,52.5733757019043,52.5733757019043,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,huggingface,forward,speed,ms,B,B,16,104.44764709472656,104.44764709472656,104.44764709472656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:27:56,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,2,31.566062927246094,31.457612991333008,31.674514770507812,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,4,61.4403190612793,61.4403190612793,61.4403190612793,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,8,119.97705841064453,119.97705841064453,119.97705841064453,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,liger,full,speed,ms,B,B,16,238.13417053222656,238.13417053222656,238.13417053222656,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:28:27,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,2,39.811119079589844,39.65474319458008,39.96749496459961,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,4,77.20928192138672,77.20928192138672,77.20928192138672,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,8,153.6952667236328,153.6952667236328,153.6952667236328,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,huggingface,full,speed,ms,B,B,16,307.7382507324219,307.7382507324219,307.7382507324219,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:00,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,2,7675.3291015625,7675.3291015625,7675.3291015625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,4,7723.3447265625,7723.3447265625,7723.3447265625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,8,7819.3759765625,7819.3759765625,7819.3759765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,liger,full,memory,MB,B,B,16,8011.4384765625,8011.4384765625,8011.4384765625,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:29:33,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314453125,8645.314453125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
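Summarizing the new rows at the largest benchmarked batch size (B=16): the Liger chunked SimPO kernel peaks at ~8,011 MB versus ~33,418 MB for the Hugging Face baseline (roughly a 4x memory reduction); its forward-only pass is slower (~236 ms vs ~104 ms), but the full forward+backward pass is faster (~238 ms vs ~308 ms).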
191 changes: 191 additions & 0 deletions benchmark/scripts/benchmark_simpo_loss.py
@@ -0,0 +1,191 @@
import os
import sys

import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchLMHeadSimPO(torch.nn.Module):
"""Ground truth implementation of the linear fused with torch based cross entropy loss.

:param H: hidden size
:param V: vocab size
:param ignore_index: index to ignore
:param reduction: reduction method
"""

def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
from test.chunked_loss.test_cpo_loss import HFCPOLoss

super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.simpo_loss = HFCPOLoss(loss_type="simpo").get_batch_loss_metrics

def forward(self, x, y):
return self.simpo_loss(x, self.lin.weight, y)


class LigerLMHeadSimPO(torch.nn.Module):
def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype
)
self.simpo_loss = LigerFusedLinearSimPOFunction.apply

def forward(self, x, y):
return self.simpo_loss(x, self.lin.weight, y)


#############################################################################
# Test the memory consumption of the fused linear SimPO loss
#############################################################################


def bench_memory_fused_linear_simpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

device = "cuda"
torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_simpo(_input, target)
elif provider == "huggingface":
return torch_lm_head_simpo(_input, target)

def full():
y = fwd()
y.backward()

mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


#############################################################################
# Test the speed of the fused linear SimPO loss
#############################################################################


def bench_speed_fused_linear_simpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

device = "cuda"

torch_lm_head_simpo = TorchLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)
liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_simpo(_input, target)
elif provider == "huggingface":
return torch_lm_head_simpo(_input, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":

def full():
y = fwd()
y.backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)
return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

common_configs = {
"kernel_name": "fused_linear_simpo_loss",
"x_name": "B",
"x_label": "B",
"x_values": [2**i for i in range(1, 5)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 1024,
"H": 4096,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_fused_linear_simpo_loss,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
)
run_benchmarks(
bench_test_fn=bench_memory_fused_linear_simpo_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
)
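As a quick sanity check, the bench functions above can also be driven directly without run_benchmarks. The sketch below is a hedged example: it assumes SingleBenchmarkRunInput is a plain dataclass whose constructor accepts exactly the fields the bench functions read (x, kernel_provider, kernel_operation_mode, extra_benchmark_config) and that a CUDA device is available; adjust if the real constructor differs.

import torch
from utils import SingleBenchmarkRunInput  # same benchmark utils imported above

# Hypothetical smoke test for bench_speed_fused_linear_simpo_loss with a tiny config.
run_input = SingleBenchmarkRunInput(
    x=2,  # batch size B
    kernel_provider="liger",
    kernel_operation_mode="forward",
    extra_benchmark_config={
        "T": 128,
        "H": 512,
        "V": 1024,
        "mode": "forward",
        "dtype": torch.bfloat16,
    },
)
out = bench_speed_fused_linear_simpo_loss(run_input)
print(out.y_20, out.y_50, out.y_80)  # 20th/50th/80th percentile latency in ms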
8 changes: 6 additions & 2 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -32,6 +32,7 @@ def forward(
alpha=1.0,
beta=0.1,
compiled=True,
**loss_kwargs,
):
"""
Base class for fused linear layer with preference loss.
@@ -49,6 +50,7 @@ def forward(
alpha (float): Weight for the NLL loss.
beta (float): Weight for the odds ratio loss.
compiled (bool): Whether to use torch compile for chunk accumulation.
loss_kwargs (dict): Additional keyword arguments that the underlying loss function may need.
"""
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
CHUNK_SIZE = chunk_size
@@ -68,6 +70,7 @@
beta=beta,
compute_nll_loss=compute_nll_loss,
full_target=target,
**loss_kwargs,
)

def accumulate_chunk(input_chunk, target_chunk):
@@ -94,6 +97,9 @@ def accumulate_chunk(input_chunk, target_chunk):
loss_acc.add_(chunk_loss)
return chunk_grad_input

if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)

len_chosen = target.shape[0] // 2
_chosen_input_chunks = torch.chunk(_input[:len_chosen], chunks=chunks, dim=0)
_chosen_target_chunks = torch.chunk(target[:len_chosen], chunks=chunks, dim=0)
@@ -116,8 +122,6 @@ def accumulate_chunk(input_chunk, target_chunk):
[chosen_target_chunk, rejected_target_chunk], dim=0
)

if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)
grad_input = accumulate_chunk(input_chunk, target_chunk)

grad_chosen_inputs.append(grad_input[: chosen_target_chunk.shape[0]])
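Besides threading **loss_kwargs through, the behavioral change in this file is that torch.compile(accumulate_chunk) is hoisted out of the per-chunk loop, so the accumulator is wrapped once instead of being re-wrapped on every iteration. A minimal, self-contained sketch of the pattern (the toy function and names are illustrative, not from this repo):

import torch

def make_accumulator(accumulate_fn, compiled: bool):
    # Wrap once, before the chunk loop; wrapping inside the loop would rebuild
    # the compiled callable on every iteration.
    return torch.compile(accumulate_fn) if compiled else accumulate_fn

def add_one(chunk):
    return chunk + 1

accumulate = make_accumulator(add_one, compiled=True)
for chunk in torch.arange(6, dtype=torch.float32).chunk(3):
    out = accumulate(chunk)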
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/orpo_loss.py
@@ -34,7 +34,7 @@ def forward(
ignore_index=-100,
beta=0.1,
compute_nll_loss=True,
compiled=True,
compiled=False,
):
"""
Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
64 changes: 64 additions & 0 deletions src/liger_kernel/chunked_loss/simpo_loss.py
@@ -0,0 +1,64 @@
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
LigerFusedLinearPreferenceBase,
)


class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):

@staticmethod
def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1, gamma=0.5):
"""
Compute odds-ratio loss.
Args:
chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
beta (float): Weight for the odds ratio loss.
gamma (float): The simpo gamma, margin term.
"""
logits = beta * (chosen_logps - rejected_logps) - gamma
Review comment (Collaborator): Shouldn't this be normalized by length, as you said in the description?

Review comment (Collaborator): Nvm, the logps are already averaged.

loss = F.logsigmoid(logits).mean()
return loss

@staticmethod
def forward(
ctx,
_input,
weight,
target,
bias=None,
ignore_index=-100,
beta=0.1,
alpha=1.0,
compute_nll_loss=False,
compiled=True,
gamma=0.5,
):
"""
Fused linear layer with SimPO (Simple Preference Optimization) loss: https://arxiv.org/pdf/2405.14734
Handles both the forward and backward pass of the final linear layer with SimPO loss.
Inspired by LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989), which fuses the final linear layer and the CE loss.
"""

return LigerFusedLinearPreferenceBase.forward(
ctx,
_input,
weight,
target,
bias,
loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
compute_nll_loss=compute_nll_loss,
ignore_index=ignore_index,
alpha=alpha,
beta=beta,
compiled=compiled,
gamma=gamma,
)

@staticmethod
def backward(ctx, grad_output):
# Get gradients for _input, weight, bias, and target from the base class
grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
# Return these gradients, followed by None for the remaining inputs
return *grads, None, None, None, None, None, None
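For reference, the logits formed in preference_loss_fn correspond to the SimPO reward margin beta * (avg chosen logp - avg rejected logp) - gamma from the paper; as the review thread above notes, the log probabilities passed in are already length-averaged, so no extra normalization is needed here. Below is a minimal usage sketch, mirroring LigerLMHeadSimPO from the benchmark script earlier in this PR; the shapes, dtype, and CUDA device are illustrative assumptions, not prescribed values.

import torch

from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

B, T, H, V = 4, 128, 512, 1024  # first B/2 rows are "chosen", last B/2 "rejected"
device, dtype = "cuda", torch.bfloat16

x = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=True)
weight = torch.randn(V, H, device=device, dtype=dtype, requires_grad=True)
target = torch.randint(V, (B, T), device=device, dtype=torch.long)

loss = LigerFusedLinearSimPOFunction.apply(x, weight, target)  # remaining args use the defaults from forward()
loss.backward()
print(loss.item(), x.grad.shape, weight.grad.shape)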
12 changes: 10 additions & 2 deletions test/chunked_loss/test_cpo_loss.py
@@ -22,11 +22,14 @@ def __init__(
beta: float = 0.1,
ignore_index: int = -100,
label_smoothing: float = 0.0,
simpo_gamma: float = 0.5,
loss_type: str = "sigmoid",
):
super().__init__(alpha=alpha, beta=beta, ignore_index=ignore_index)
# Sigmoid defaults to the CPO loss defined in the paper listed above.
self.loss_type = "sigmoid"
self.loss_type = loss_type
self.label_smoothing = label_smoothing
self.simpo_gamma = simpo_gamma

def alignment_loss(
self,
@@ -55,6 +58,12 @@ def alignment_loss(
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
elif self.loss_type == "simpo":
logits = logits - (self.simpo_gamma / self.beta)
losses = (
F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+ F.logsigmoid(-self.beta * logits) * self.label_smoothing
)
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid']"
@@ -66,7 +75,6 @@
@pytest.mark.parametrize(
"B, T, H, V",
[
# (1, 2, 12, 128),
(8, 128, 1024, 4096),
(3, 47, 31, 123), # random shape
],
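One detail worth noting in the 'simpo' branch added above: assuming logits is the usual chosen-minus-rejected log-probability difference from the surrounding CPO reference, shifting it by simpo_gamma / beta before applying F.logsigmoid(self.beta * logits) puts beta * (chosen_logps - rejected_logps) - simpo_gamma inside the log-sigmoid, which is exactly the expression used by LigerFusedLinearSimPOFunction.preference_loss_fn. With label_smoothing = 0, the Hugging Face-style reference and the chunked kernel therefore optimize the same quantity.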