add nn.module support for chunked loss function (#402)
## Summary
Adds `torch.nn.Module` wrappers (`LigerFusedLinearCPOLoss`, `LigerFusedLinearDPOLoss`, `LigerFusedLinearORPOLoss`, `LigerFusedLinearSimPOLoss`) around the existing chunked-loss autograd Functions, re-exports them from `liger_kernel.chunked_loss`, and adds a `functional` module exposing the Functions' `.apply` methods. A short usage sketch follows.
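A minimal usage sketch of the new module API (tensor sizes and the chosen-then-rejected batch layout are illustrative assumptions, not part of this diff):

```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

B, T, H, V = 4, 16, 64, 128  # hypothetical sizes; first B//2 sequences chosen, last B//2 rejected
hidden = torch.randn(B, T, H, requires_grad=True)       # hidden states before the lm_head projection
target = torch.randint(0, V, (B, T))                    # token ids; positions equal to ignore_index are masked
lm_head_weight = torch.randn(V, H, requires_grad=True)  # the linear layer's weight, passed in explicitly

orpo = LigerFusedLinearORPOLoss(ignore_index=-100, beta=0.1)
loss = orpo(lm_head_weight, hidden, target)  # note the argument order: weight first, then input
loss.backward()
```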

## Testing Done


- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
shivam15s authored Nov 21, 2024
1 parent 81d98ea commit 2a39f0d
Showing 12 changed files with 686 additions and 38 deletions.
4 changes: 4 additions & 0 deletions src/liger_kernel/chunked_loss/__init__.py
@@ -0,0 +1,4 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
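These re-exports make the new wrappers importable from the package root:

```python
from liger_kernel.chunked_loss import (
    LigerFusedLinearCPOLoss,
    LigerFusedLinearDPOLoss,
    LigerFusedLinearORPOLoss,
    LigerFusedLinearSimPOLoss,
)
```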
42 changes: 41 additions & 1 deletion src/liger_kernel/chunked_loss/cpo_loss.py
@@ -1,3 +1,4 @@
+import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -46,10 +47,10 @@ def forward(
            target,
            bias,
            loss_fn=LigerFusedLinearCPOFunction.preference_loss_fn,
-           compute_nll_loss=compute_nll_loss,
            ignore_index=ignore_index,
            alpha=alpha,
            beta=beta,
+           compute_nll_loss=compute_nll_loss,
            compiled=compiled,
        )

@@ -59,3 +60,42 @@ def backward(ctx, grad_output):
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # Return these gradients, followed by None for the remaining inputs
        return *grads, None, None, None, None, None


class LigerFusedLinearCPOLoss(torch.nn.Module):
    """
    Fused linear layer with CPO loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight (temperature) for the preference loss term.
            alpha (float): Weight for the chosen-response NLL loss term.
            compute_nll_loss (bool): Whether to also compute the NLL loss on the chosen responses.
            compiled (bool): Whether to use torch.compile for the chunked loss computation.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        return LigerFusedLinearCPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
        )
39 changes: 38 additions & 1 deletion src/liger_kernel/chunked_loss/dpo_loss.py
@@ -1,3 +1,4 @@
+import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -43,9 +44,9 @@ def forward(
            target=target,
            bias=bias,
            loss_fn=LigerFusedLinearDPOFunction.preference_loss_fn,
-           compute_nll_loss=compute_nll_loss,
            ignore_index=ignore_index,
            beta=beta,
+           compute_nll_loss=compute_nll_loss,
            compiled=compiled,
        )

@@ -55,3 +56,39 @@ def backward(ctx, grad_output):
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # Return these gradients, followed by None for the remaining inputs
        return *grads, None, None, None, None


class LigerFusedLinearDPOLoss(torch.nn.Module):
    """
    Fused linear layer with DPO loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Temperature for the DPO loss.
            compute_nll_loss (bool): Whether to also compute the NLL loss on the chosen responses.
            compiled (bool): Whether to use torch.compile for the chunked loss computation.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        return LigerFusedLinearDPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
        )
9 changes: 9 additions & 0 deletions src/liger_kernel/chunked_loss/functional.py
@@ -0,0 +1,9 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction

liger_fused_linear_orpo = LigerFusedLinearORPOFunction.apply
liger_fused_linear_dpo = LigerFusedLinearDPOFunction.apply
liger_fused_linear_cpo = LigerFusedLinearCPOFunction.apply
liger_fused_linear_simpo = LigerFusedLinearSimPOFunction.apply
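These aliases invoke the autograd Functions directly, so all arguments are positional, in the order of each Function's `forward`. A sketch for the CPO alias (sizes are hypothetical; the argument order is read off the `LigerFusedLinearCPOFunction.apply` call above):

```python
import torch

from liger_kernel.chunked_loss.functional import liger_fused_linear_cpo

B, T, H, V = 4, 16, 64, 128  # hypothetical sizes
hidden = torch.randn(B, T, H, requires_grad=True)
target = torch.randint(0, V, (B, T))
weight = torch.randn(V, H, requires_grad=True)

# Positional order: _input, weight, target, bias, ignore_index, beta, alpha, compute_nll_loss, compiled
loss = liger_fused_linear_cpo(hidden, weight, target, None, -100, 0.1, 1.0, True, True)
loss.backward()
```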
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -27,10 +27,10 @@ def forward(
        bias=None,
        loss_fn=None,
        chunk_size=1,
-       compute_nll_loss=True,
        ignore_index=-100,
        alpha=1.0,
        beta=0.1,
+       compute_nll_loss=True,
        compiled=True,
        **loss_kwargs,
    ):
40 changes: 38 additions & 2 deletions src/liger_kernel/chunked_loss/orpo_loss.py
Expand Up @@ -34,7 +34,7 @@ def forward(
        ignore_index=-100,
        beta=0.1,
        compute_nll_loss=True,
-       compiled=False,
+       compiled=True,
    ):
        """
        Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
@@ -49,9 +49,9 @@
            target=target,
            bias=bias,
            loss_fn=LigerFusedLinearORPOFunction.preference_loss_fn,
-           compute_nll_loss=compute_nll_loss,
            ignore_index=ignore_index,
            beta=beta,
+           compute_nll_loss=compute_nll_loss,
            compiled=compiled,
        )

@@ -61,3 +61,39 @@ def backward(ctx, grad_output):
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # Return these gradients, followed by None for the remaining inputs
        return *grads, None, None, None, None


class LigerFusedLinearORPOLoss(torch.nn.Module):
    """
    Fused linear layer with ORPO (Odds-Ratio Preference Optimization) loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        compute_nll_loss: bool = True,
        compiled: bool = True,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Weight for the odds ratio loss.
            compute_nll_loss (bool): Whether to also compute the NLL loss on the chosen responses.
            compiled (bool): Whether to use torch.compile for the chunked loss computation.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled

    def forward(self, lin_weight, _input, target, bias=None):
        return LigerFusedLinearORPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.compute_nll_loss,
            self.compiled,
        )
43 changes: 43 additions & 0 deletions src/liger_kernel/chunked_loss/simpo_loss.py
@@ -1,3 +1,4 @@
+import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_preference import (
@@ -62,3 +63,45 @@ def backward(ctx, grad_output):
        grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
        # Return these gradients, followed by None for the remaining inputs
        return *grads, None, None, None, None, None, None


class LigerFusedLinearSimPOLoss(torch.nn.Module):
    """
    Fused linear layer with SimPO loss.
    """

    def __init__(
        self,
        ignore_index: int = -100,
        beta: float = 0.1,
        alpha: float = 1.0,
        compute_nll_loss: bool = True,
        compiled: bool = True,
        gamma: float = 0.5,
    ):
        """
        Args:
            ignore_index (int): Index to ignore in the loss.
            beta (float): Temperature for the preference loss term.
            alpha (float): Weight for the chosen-response NLL loss term.
            compute_nll_loss (bool): Whether to also compute the NLL loss on the chosen responses.
            compiled (bool): Whether to use torch.compile for the chunked loss computation.
            gamma (float): Target reward margin for the SimPO loss.
        """
        super().__init__()
        self.ignore_index = ignore_index
        self.beta = beta
        self.alpha = alpha
        self.compute_nll_loss = compute_nll_loss
        self.compiled = compiled
        self.gamma = gamma

    def forward(self, lin_weight, _input, target, bias=None):
        return LigerFusedLinearSimPOFunction.apply(
            _input,
            lin_weight,
            target,
            bias,
            self.ignore_index,
            self.beta,
            self.alpha,
            self.compute_nll_loss,
            self.compiled,
            self.gamma,
        )
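SimPO extends the CPO-style arguments with `gamma`; reading it as the SimPO target reward margin is a gloss, not something this diff states. Construction and use mirror the other wrappers:

```python
from liger_kernel.chunked_loss import LigerFusedLinearSimPOLoss

# gamma: assumed to be the SimPO target reward margin
simpo = LigerFusedLinearSimPOLoss(beta=0.1, gamma=0.5)
# loss = simpo(lm_head_weight, hidden, target)  # same call shape as the ORPO sketch above
```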