-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Chunked SimPO Loss #386
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
0ac4e9b
Initial commit
pramodith 6357eed
Add SimPO Loss
pramodith cd5c0da
Merge branch 'main' into pramodith/chunked_simpo_loss
pramodith 965ee55
Fix merge
pramodith bf69261
Merge branch 'main' into pramodith/chunked_simpo_loss
ByronHsu 98706f6
Merge branch 'main' into pramodith/chunked_simpo_loss
pramodith 3ef9bad
Fix checkstyle
pramodith 0534bb3
compile just once.
pramodith 7422b7e
fix checkstyle
pramodith File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
import triton | ||
from utils import ( | ||
QUANTILES, | ||
SingleBenchmarkRunInput, | ||
SingleBenchmarkRunOutput, | ||
_test_memory, | ||
parse_benchmark_script_args, | ||
run_benchmarks, | ||
) | ||
|
||
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction | ||
|
||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) | ||
|
||
|
||
class TorchLMHeadSimPO(torch.nn.Module):
    """Reference LM head for the SimPO benchmark.

    Holds a bias-free linear projection and delegates the loss computation to
    ``HFCPOLoss(loss_type="simpo").get_batch_loss_metrics`` from the test
    helpers, calling it with the hidden states, the projection weight and the
    targets.

    :param H: hidden size (input features of the projection)
    :param V: vocab size (output features of the projection)
    :param dtype: dtype of the projection weight
    :param ignore_index: accepted for signature parity; not used by this class
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        # Local import: resolved when a model is constructed rather than when
        # this module is imported.
        from test.chunked_loss.test_cpo_loss import HFCPOLoss

        super().__init__()
        self.lin = torch.nn.Linear(H, V, bias=False, dtype=dtype)
        reference_loss = HFCPOLoss(loss_type="simpo")
        self.simpo_loss = reference_loss.get_batch_loss_metrics

    def forward(self, x, y):
        """Return the reference loss for hidden states ``x`` and targets ``y``."""
        projection_weight = self.lin.weight
        return self.simpo_loss(x, projection_weight, y)
|
||
|
||
class LigerLMHeadSimPO(torch.nn.Module):
    """LM head for the SimPO benchmark backed by ``LigerFusedLinearSimPOFunction``.

    Holds a bias-free linear projection; the projection weight is handed to the
    fused function together with the hidden states and targets.

    :param H: hidden size (input features of the projection)
    :param V: vocab size (output features of the projection)
    :param dtype: dtype of the projection weight
    :param ignore_index: forwarded to the fused function as ``ignore_index``
    """

    def __init__(self, H: int, V: int, dtype: torch.dtype, ignore_index: int = -100):
        super().__init__()
        self.lin = torch.nn.Linear(
            in_features=H, out_features=V, bias=False, dtype=dtype
        )
        # Previously accepted but dropped; kept so forward() can pass it on.
        self.ignore_index = ignore_index
        self.simpo_loss = LigerFusedLinearSimPOFunction.apply

    def forward(self, x, y):
        """Return the fused SimPO loss for hidden states ``x`` and targets ``y``."""
        # ``apply`` is called positionally, following the function's forward
        # signature: (_input, weight, target, bias, ignore_index). The
        # projection has no bias, so ``None`` fills that slot.
        return self.simpo_loss(x, self.lin.weight, y, None, self.ignore_index)
|
||
|
||
############################################################################# | ||
# Test the memory consumption of the fused linear SimPO loss | ||
############################################################################# | ||
|
||
|
||
def bench_memory_fused_linear_simpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Measure memory of a forward + backward pass through one SimPO LM head.

    ``input.x`` is used as the batch size ``B``; ``T``, ``H``, ``V`` and
    ``dtype`` are read from ``input.extra_benchmark_config`` and
    ``input.kernel_provider`` selects the implementation.

    :param input: benchmark run description
    :return: the 20/50/80 quantile values reported by ``_test_memory``
    :raises ValueError: if the provider is neither "liger" nor "huggingface"
    """
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider

    # Fail fast with a clear message instead of letting an unknown provider
    # surface later as an AttributeError on ``None``.
    if provider == "liger":
        lm_head_cls = LigerLMHeadSimPO
    elif provider == "huggingface":
        lm_head_cls = TorchLMHeadSimPO
    else:
        raise ValueError(
            f"Unknown kernel provider {provider!r}; "
            "expected 'liger' or 'huggingface'"
        )

    device = "cuda"
    # Build only the selected head so the other provider's V x H weight is
    # never allocated on the device while memory is being measured.
    lm_head_simpo = lm_head_cls(H=H, V=V, dtype=dtype).to(device)

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    def full():
        y = lm_head_simpo(_input, target)
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )
|
||
|
||
# ############################################################################# | ||
# # Test the speed of the fused linear SimPO loss | ||
# ############################################################################# | ||
|
||
|
||
def bench_speed_fused_linear_simpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    """Time one SimPO LM head with ``triton.testing.do_bench``.

    ``input.x`` is used as the batch size ``B``; ``T``, ``H``, ``V`` and
    ``dtype`` are read from ``input.extra_benchmark_config``.
    ``input.kernel_provider`` selects the implementation and
    ``input.kernel_operation_mode`` selects what is timed:

    * ``"forward"``  - the forward pass only
    * ``"backward"`` - ``backward(retain_graph=True)`` on one precomputed loss
    * ``"full"``     - forward followed by backward

    :param input: benchmark run description
    :return: the 20/50/80 quantile values reported by ``do_bench``
    :raises ValueError: if the provider or the mode is not one of the above
    """
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    # Validate before allocating anything: an unknown provider used to make
    # fwd() return None, and an unknown mode used to end in an
    # UnboundLocalError on the result variables.
    if provider == "liger":
        lm_head_cls = LigerLMHeadSimPO
    elif provider == "huggingface":
        lm_head_cls = TorchLMHeadSimPO
    else:
        raise ValueError(
            f"Unknown kernel provider {provider!r}; "
            "expected 'liger' or 'huggingface'"
        )
    if mode not in ("forward", "backward", "full"):
        raise ValueError(
            f"Unknown kernel operation mode {mode!r}; "
            "expected 'forward', 'backward' or 'full'"
        )

    device = "cuda"

    # Only the selected head is built; the other one is never used.
    lm_head_simpo = lm_head_cls(H=H, V=V, dtype=dtype).to(device)

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    def fwd():
        return lm_head_simpo(_input, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        # One forward pass up front; the graph is retained so that backward
        # can be replayed on every timing iteration.
        y = fwd()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    else:  # mode == "full", guaranteed by the validation above

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
|
||
|
||
if __name__ == "__main__":
    # Only ``args.overwrite`` is read from the parsed script arguments here.
    args = parse_benchmark_script_args()

    batch_sizes = [2**exponent for exponent in range(1, 5)]
    model_config = {
        "T": 1024,
        "H": 4096,
        "V": 128256,
        "mode": "forward",
        "dtype": torch.bfloat16,
    }
    # Settings shared by the speed and the memory benchmark runs.
    common_configs = dict(
        kernel_name="fused_linear_simpo_loss",
        x_name="B",
        x_label="B",
        x_values=batch_sizes,
        kernel_providers=["liger", "huggingface"],
        extra_benchmark_configs=[model_config],
        overwrite=args.overwrite,
    )

    # (benchmark function, operation modes, metric name, metric unit);
    # executed in this order: speed first, then memory.
    benchmark_runs = (
        (bench_speed_fused_linear_simpo_loss, ["forward", "full"], "speed", "ms"),
        (bench_memory_fused_linear_simpo_loss, ["full"], "memory", "MB"),
    )
    for bench_fn, operation_modes, metric_name, metric_unit in benchmark_runs:
        run_benchmarks(
            bench_test_fn=bench_fn,
            kernel_operation_modes=operation_modes,
            metric_name=metric_name,
            metric_unit=metric_unit,
            **common_configs,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import torch.nn.functional as F | ||
|
||
from liger_kernel.chunked_loss.fused_linear_preference import ( | ||
LigerFusedLinearPreferenceBase, | ||
) | ||
|
||
|
||
class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
    """Fused linear + SimPO loss built on ``LigerFusedLinearPreferenceBase``."""

    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1, gamma=0.5):
        """
        Compute the SimPO preference term.
        Args:
            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
            beta (float): Scale applied to the chosen-minus-rejected difference.
            gamma (float): The SimPO margin, subtracted from the scaled difference.
        Returns:
            Mean of ``logsigmoid(beta * (chosen_logps - rejected_logps) - gamma)``.
        """
        logp_gap = chosen_logps - rejected_logps
        return F.logsigmoid(beta * logp_gap - gamma).mean()

    @staticmethod
    def forward(
        ctx,
        _input,
        weight,
        target,
        bias=None,
        ignore_index=-100,
        beta=0.1,
        alpha=1.0,
        compute_nll_loss=False,
        compiled=True,
        gamma=0.5,
    ):
        """
        Fused linear layer with SimPO (Simple Preference Optimization) loss. https://arxiv.org/pdf/2405.14734
        Handles both the forward and backward pass of the final linear layer with SimPO loss.
        Inspired from LigerFusedLinearCrossEntropyFunction (https://arxiv.org/abs/2410.10989) which fuses final linear layer and CE loss.

        All arguments are passed through to ``LigerFusedLinearPreferenceBase.forward``,
        with ``preference_loss_fn`` of this class supplied as ``loss_fn``.
        """
        base_kwargs = dict(
            loss_fn=LigerFusedLinearSimPOFunction.preference_loss_fn,
            compute_nll_loss=compute_nll_loss,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
            compiled=compiled,
            gamma=gamma,
        )
        return LigerFusedLinearPreferenceBase.forward(
            ctx, _input, weight, target, bias, **base_kwargs
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Return the base-class gradients padded with ``None`` for the extra inputs."""
        # Keep the first four entries of the base-class result: one slot per
        # leading forward argument (_input, weight, target, bias).
        leading_grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # The six remaining forward arguments (ignore_index, beta, alpha,
        # compute_nll_loss, compiled, gamma) receive no gradient.
        return (*leading_grads, *([None] * 6))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this be normalized by length, as you said in the description?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nvm, the logps are already averaged