Add ignore_index and label to jsd and fl-jsd (#306)
Tcc0403 authored Oct 16, 2024
1 parent 3146916 commit 24a7efc
Showing 9 changed files with 475 additions and 88 deletions.
README.md (3 additions, 0 deletions)
@@ -254,6 +254,7 @@ loss.backward()
| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
| JSD | `liger_kernel.transformers.LigerJSD` |
| FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` |

- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup.
@@ -269,6 +270,8 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
- **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage.
- **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward pass into a single Triton kernel, with the reduction done outside the kernel. It achieves ~1.5X speedup and ~15% memory reduction for a 128K vocab size.
- **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence) is implemented by computing both the loss and the gradient in the forward pass. It achieves ~1.5X speedup and ~54% memory reduction for a 128k vocab size.
- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation. It achieves ~85% memory reduction for 128k vocab size when batch size $\times$ sequence length is 8192.
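
A minimal usage sketch of the fused JSD kernel with the new label masking. The constructor and forward signatures follow the benchmark wrapper later in this diff; the tensor sizes, dtype, and CUDA device are illustrative assumptions.

```python
import torch

from liger_kernel.transformers import LigerFusedLinearJSD

# Illustrative sizes: BT = batch size * sequence length, H = hidden dim, V = vocab size
BT, H, V = 8, 64, 128
device, dtype = "cuda", torch.bfloat16  # Triton kernels assume a CUDA device

student_head = torch.nn.Linear(H, V, bias=False, dtype=dtype, device=device)
teacher_head = torch.nn.Linear(H, V, bias=False, dtype=dtype, device=device)
fused_jsd = LigerFusedLinearJSD(jsd_beta=0.5, ignore_index=-100, temperature=1.0)

student_input = torch.randn(BT, H, dtype=dtype, device=device, requires_grad=True)
teacher_input = torch.randn(BT, H, dtype=dtype, device=device)
label = torch.randint(0, V, (BT,), device=device)
label[0] = -100  # rows whose label equals ignore_index are excluded from the loss

loss = fused_jsd(
    student_input,
    student_head.weight,
    teacher_input,
    teacher_head.weight,
    label,  # optional; omit it to average over all BT rows
)
loss.backward()
```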


### Experimental Kernels

benchmark/scripts/benchmark_fused_linear_jsd.py (37 additions, 15 deletions)
@@ -13,23 +13,40 @@


class TorchJSD(torch.nn.Module):
def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float):
def __init__(
self,
beta: float = 0.5,
ignore_index: int = -100,
dtype: torch.dtype = torch.float,
):
super(TorchJSD, self).__init__()
self.kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)
self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
self.beta = beta
self.ignore_index = ignore_index
self.dtype = dtype

def forward(
self,
log_q: torch.tensor, # input
log_p: torch.tensor, # target
log_q: torch.Tensor, # input
log_p: torch.Tensor, # target
label=None,
):
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
torch.log(m), log_q
)
m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
1 - self.beta
) * self.kl(torch.log(m), log_q).sum(dim=-1)

if label is not None:
loss = torch.where(label != self.ignore_index, loss, 0.0)
n_non_ignore = (label != self.ignore_index).sum().item()
if n_non_ignore == 0:
loss = loss.sum() * 0.0  # keep a zero tensor (not a float) so .to(self.dtype) below works
else:
loss = (loss / n_non_ignore).sum()
else:
loss = (loss / log_q.shape[0]).sum()
return loss.to(self.dtype)
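
For reference, the quantity the forward pass above computes is the beta-skewed generalized JSD between the teacher distribution $P$ (from `log_p`) and the student distribution $Q$ (from `log_q`), restated here as an equation:

$$\mathrm{JSD}_\beta(P \,\|\, Q) = \beta\,\mathrm{KL}(P \,\|\, M) + (1-\beta)\,\mathrm{KL}(Q \,\|\, M), \qquad M = \beta P + (1-\beta)\,Q$$

When a `label` tensor is given, rows where `label == ignore_index` contribute zero and the summed loss is divided by the number of non-ignored rows; otherwise it is divided by the total number of rows, matching the previous `batchmean` behavior.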


@@ -48,8 +65,9 @@ def __init__(
V: int,
dtype: torch.dtype,
device: torch.device,
temperature: float = 1.0,
beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
super().__init__()
self.student_lin = torch.nn.Linear(
@@ -58,16 +76,16 @@ def __init__(
self.teacher_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
)
self.jsd = TorchJSD(beta, dtype=dtype)
self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype)
self.temperature = temperature

def forward(self, student_input, teacher_input):
def forward(self, student_input, teacher_input, label=None):
student_logits = self.student_lin(student_input)
teacher_logits = self.teacher_lin(teacher_input)
student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1)
teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1)

return self.jsd(student_prob, teacher_prob)
return self.jsd(student_prob, teacher_prob, label)


class LigerLMHeadJSD(torch.nn.Module):
@@ -77,8 +95,9 @@ def __init__(
V: int,
dtype: torch.dtype,
device: torch.device,
temperature: float = 1.0,
beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
super().__init__()
self.student_lin = torch.nn.Linear(
@@ -87,14 +106,17 @@ def __init__(
self.teacher_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
)
self.fused_jsd = LigerFusedLinearJSD(beta, temperature)
self.fused_jsd = LigerFusedLinearJSD(
jsd_beta=beta, ignore_index=ignore_index, temperature=temperature
)

def forward(self, student_input, teacher_input):
def forward(self, student_input, teacher_input, label=None):
return self.fused_jsd(
student_input,
self.student_lin.weight,
teacher_input,
self.teacher_lin.weight,
label,
)


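
A quick sanity check of the new label masking in the `TorchJSD` reference above: the sketch evaluates the same log-probabilities with and without an ignored row. The seed and sizes are arbitrary, and `TorchJSD` is assumed to be in scope from this script.

```python
import torch

torch.manual_seed(0)
jsd = TorchJSD(beta=0.5, ignore_index=-100)  # reference module defined above

log_q = torch.log_softmax(torch.randn(4, 16), dim=-1)  # student log-probs
log_p = torch.log_softmax(torch.randn(4, 16), dim=-1)  # teacher log-probs

label = torch.tensor([3, 7, 1, 5])
masked = label.clone()
masked[0] = -100  # drop the first row from the loss

full_loss = jsd(log_q, log_p, label)     # averaged over all 4 rows
masked_loss = jsd(log_q, log_p, masked)  # averaged over the 3 remaining rows
print(full_loss, masked_loss)
```
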
benchmark/scripts/benchmark_jsd.py (26 additions, 10 deletions)
@@ -1,5 +1,4 @@
import torch
import torch.nn as nn
import triton
from utils import (
QUANTILES,
@@ -13,24 +12,41 @@
from liger_kernel.transformers.jsd import LigerJSD


class TorchJSD(nn.Module):
def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float):
class TorchJSD(torch.nn.Module):
def __init__(
self,
beta: float = 0.5,
ignore_index: int = -100,
dtype: torch.dtype = torch.float,
):
super(TorchJSD, self).__init__()
self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True)
self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
self.beta = beta
self.ignore_index = ignore_index
self.dtype = dtype

def forward(
self,
log_q: torch.tensor, # input
log_p: torch.tensor, # target
log_q: torch.Tensor, # input
log_p: torch.Tensor, # target
label=None,
):
log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl(
torch.log(m), log_q
)
m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (
1 - self.beta
) * self.kl(torch.log(m), log_q).sum(dim=-1)

if label is not None:
loss = torch.where(label != self.ignore_index, loss, 0.0)
n_non_ignore = (label != self.ignore_index).sum().item()
if n_non_ignore == 0:
loss = loss.sum() * 0.0  # keep a zero tensor (not a float) so .to(self.dtype) below works
else:
loss = (loss / n_non_ignore).sum()
else:
loss = (loss / log_q.shape[0]).sum()
return loss.to(self.dtype)


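
A short comparison sketch between the `TorchJSD` reference above and `LigerJSD`. The `LigerJSD` constructor keywords (`beta`, `ignore_index`) and the optional label argument of its forward are assumed to mirror the reference implementation; treat the exact parameter names as assumptions rather than the confirmed API.

```python
import torch

from liger_kernel.transformers.jsd import LigerJSD

torch.manual_seed(0)
BT, V = 16, 1024   # illustrative sizes
device = "cuda"    # the Liger kernel is Triton-based, so a CUDA device is assumed

log_q = torch.log_softmax(torch.randn(BT, V, device=device), dim=-1)  # student
log_p = torch.log_softmax(torch.randn(BT, V, device=device), dim=-1)  # teacher
label = torch.randint(0, V, (BT,), device=device)
label[:4] = -100  # ignored positions

torch_jsd = TorchJSD(beta=0.5, ignore_index=-100)  # reference defined above
liger_jsd = LigerJSD(beta=0.5, ignore_index=-100)  # assumed keyword names

print(torch_jsd(log_q, log_p, label))
print(liger_jsd(log_q, log_p, label))  # expected to match the reference closely
```
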
src/liger_kernel/ops/fused_linear_jsd.py (38 additions, 8 deletions)
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import triton

@@ -15,7 +17,10 @@ def fused_linear_jsd_forward(
student_weight,
teacher_input,
teacher_weight,
shift_labels,
jsd_beta,
ignore_index,
has_label,
temperature,
):
device = student_input.device
@@ -46,6 +51,11 @@
# we use fp32 for loss accumulator
loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device)

if has_label:
n_non_ignore = (shift_labels != ignore_index).sum().item()
else:
n_non_ignore = BT

for chunk_id in range(num_chunks):
start_idx = chunk_id * chunk_size
end_idx = min((chunk_id + 1) * chunk_size, BT)
@@ -81,10 +91,15 @@
loss_stride=loss_1d_slice.stride(-2),
dX_ptr=student_prob_chunk,
dX_stride=student_prob_chunk.stride(-2),
label_ptr=(
shift_labels if has_label else torch.empty(1, device=device)
), # dummy ptr if no label
beta=jsd_beta,
n_rows=BT, # batchmean
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
n_cols=V,
BLOCK_SIZE=BLOCK_SIZE,
HAS_LABEL=has_label,
)
loss_1d[start_idx:end_idx] = loss_1d_slice
# gradients of prob_chunk in place, shape: chunk_size x V
@@ -157,12 +172,14 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
student_input,
student_weight,
teacher_input,
teacher_weight,
jsd_beta=0.5,
temperature=1.0,
student_input: torch.Tensor,
student_weight: torch.Tensor,
teacher_input: torch.Tensor,
teacher_weight: torch.Tensor,
shift_labels: Optional[torch.Tensor] = None,
jsd_beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
"""
Args:
@@ -171,18 +188,31 @@ def forward(
student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size
teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension.
teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size
shift_labels (Optional[torch.LongTensor]): next-token labels with shape (BT,), where each value is in [0, V-1] or equals `ignore_index` to exclude that position from the loss.
jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5`
ignore_index (int): the index to ignore. Default: -100
temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0`
Returns:
loss (torch.Tensor): generalized JSD
"""
has_label = False
if shift_labels is not None:
assert shift_labels.shape == (
teacher_input.shape[0],
), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
shift_labels = shift_labels.contiguous()
has_label = True

loss, grad_input, grad_weight = fused_linear_jsd_forward(
student_input,
student_weight,
teacher_input,
teacher_weight,
shift_labels,
jsd_beta,
ignore_index,
has_label,
temperature,
)
# downcast to dtype and store for backward
@@ -198,4 +228,4 @@ def backward(ctx, grad_output):
grad_input, grad_weight = fused_linear_jsd_backward(
grad_output, grad_input, grad_weight
)
return (grad_input, grad_weight, None, None, None, None)
return (grad_input, grad_weight, None, None, None, None, None, None)
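
Finally, a sketch of calling the autograd function directly with the new argument order defined in `forward` above (`student_input`, `student_weight`, `teacher_input`, `teacher_weight`, `shift_labels`, `jsd_beta`, `ignore_index`, `temperature`). `apply` takes positional arguments only, and the sizes, dtype, and CUDA device here are illustrative assumptions.

```python
import torch

from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction

BT, H, V = 4, 32, 64  # illustrative sizes
device, dtype = "cuda", torch.bfloat16

student_input = torch.randn(BT, H, device=device, dtype=dtype, requires_grad=True)
student_weight = torch.randn(V, H, device=device, dtype=dtype, requires_grad=True)
teacher_input = torch.randn(BT, H, device=device, dtype=dtype)
teacher_weight = torch.randn(V, H, device=device, dtype=dtype)
shift_labels = torch.randint(0, V, (BT,), device=device)
shift_labels[-1] = -100  # ignored position

loss = LigerFusedLinearJSDFunction.apply(
    student_input,
    student_weight,
    teacher_input,
    teacher_weight,
    shift_labels,  # pass None to average over all BT rows instead
    0.5,           # jsd_beta
    -100,          # ignore_index
    1.0,           # temperature
)
loss.backward()
```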
(Diffs for the remaining five changed files are not shown here.)
