diff --git a/README.md b/README.md index bbd8d03d4..c6ff6ee5d 100644 --- a/README.md +++ b/README.md @@ -254,6 +254,7 @@ loss.backward() | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`| | KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` | | JSD | `liger_kernel.transformers.LigerJSD` | +| FusedLinearJSD | `liger_kernel.transformers.LigerFusedLinearJSD` | - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction. - **LayerNorm**: [LayerNorm](https://arxiv.org/pdf/1607.06450), which centers and normalizes activations across the feature dimension, is implemented by fusing the centering, normalization and scaling steps into a single Triton kernel, and achieves ~2X speedup. @@ -269,6 +270,8 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$ - **FusedLinearCrossEntropy**: Peak memory usage of cross entropy loss is further improved by fusing the model head with the CE loss and chunking the input for block-wise loss and gradient calculation, a technique inspired by [Efficient Cross Entropy](https://github.com/mgmalek/efficient_cross_entropy). It achieves >4X memory reduction for 128k vocab size. **This is highly effective for large batch size, large sequence length, and large vocabulary sizes.** Please refer to the [Medusa example](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) for individual kernel usage. - **KLDivergence**: [KL Divergence](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) is implemented by fusing the forward into a single triton kernel, with reduction done outside the kernel. It achieves ~1.5X speed and ~15% memory reduction for 128K vocab size. - **JSD**: [Generalized JSD](https://arxiv.org/pdf/2306.13649) (Jensen-Shannon divergence), is implemented by computing both the loss and gradient in the forward pass. It achieves ~1.5X speed and ~54% memory reduction for 128k vocab size. +- **FusedLinearJSD**: Peak memory usage of JSD loss is further improved by fusing the model head with the JSD and chunking the input for block-wise loss and gradient calculation (see the usage sketch below). It achieves ~85% memory reduction for 128k vocab size where batch size $\times$ sequence length is 8192.
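A minimal usage sketch of the fused path (sizes and the `shift_labels` tensor below are illustrative assumptions; the module name, constructor arguments, and call signature follow the `LigerFusedLinearJSD` code added later in this patch):

```python
import torch
from liger_kernel.transformers import LigerFusedLinearJSD

# Illustrative sizes: BT = batch size * sequence length, student/teacher hidden dims, vocab size
BT, H_student, H_teacher, V = 8, 32, 64, 128

student_input = torch.randn(BT, H_student, device="cuda", requires_grad=True)
student_lin = torch.nn.Linear(H_student, V, bias=False, device="cuda")
# teacher activations do not require gradients; the teacher hidden size may differ
teacher_input = torch.randn(BT, H_teacher, device="cuda")
teacher_lin = torch.nn.Linear(H_teacher, V, bias=False, device="cuda")

# optional hard labels; positions equal to ignore_index contribute no loss
shift_labels = torch.randint(0, V, (BT,), device="cuda", dtype=torch.long)

fused_jsd = LigerFusedLinearJSD(jsd_beta=0.5, ignore_index=-100, temperature=1.0)
loss = fused_jsd(
    student_input, student_lin.weight, teacher_input, teacher_lin.weight, shift_labels
)
loss.backward()
```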
+ ### Experimental Kernels diff --git a/benchmark/scripts/benchmark_fused_linear_jsd.py b/benchmark/scripts/benchmark_fused_linear_jsd.py index ee32428a7..7f652de8a 100644 --- a/benchmark/scripts/benchmark_fused_linear_jsd.py +++ b/benchmark/scripts/benchmark_fused_linear_jsd.py @@ -13,23 +13,40 @@ class TorchJSD(torch.nn.Module): - def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float): + def __init__( + self, + beta: float = 0.5, + ignore_index: int = -100, + dtype: torch.dtype = torch.float, + ): super(TorchJSD, self).__init__() - self.kl = torch.nn.KLDivLoss(reduction="batchmean", log_target=True) + self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True) self.beta = beta + self.ignore_index = ignore_index self.dtype = dtype def forward( self, - log_q: torch.tensor, # input - log_p: torch.tensor, # target + log_q: torch.Tensor, # input + log_p: torch.Tensor, # target + label=None, ): log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) - m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta) - loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl( - torch.log(m), log_q - ) + m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( + 1 - self.beta + ) * self.kl(torch.log(m), log_q).sum(dim=-1) + + if label is not None: + loss = torch.where(label != self.ignore_index, loss, 0.0) + n_non_ignore = (label != self.ignore_index).sum().item() + if n_non_ignore == 0: + loss = 0.0 + else: + loss = (loss / n_non_ignore).sum() + else: + loss = (loss / log_q.shape[0]).sum() return loss.to(self.dtype) @@ -48,8 +65,9 @@ def __init__( V: int, dtype: torch.dtype, device: torch.device, - temperature: float = 1.0, beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, ): super().__init__() self.student_lin = torch.nn.Linear( @@ -58,16 +76,16 @@ def __init__( self.teacher_lin = torch.nn.Linear( in_features=H, out_features=V, bias=False, dtype=dtype, device=device ) - self.jsd = TorchJSD(beta, dtype=dtype) + self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype) self.temperature = temperature - def forward(self, student_input, teacher_input): + def forward(self, student_input, teacher_input, label=None): student_logits = self.student_lin(student_input) teacher_logits = self.teacher_lin(teacher_input) student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1) teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1) - return self.jsd(student_prob, teacher_prob) + return self.jsd(student_prob, teacher_prob, label) class LigerLMHeadJSD(torch.nn.Module): @@ -77,8 +95,9 @@ def __init__( V: int, dtype: torch.dtype, device: torch.device, - temperature: float = 1.0, beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, ): super().__init__() self.student_lin = torch.nn.Linear( @@ -87,14 +106,17 @@ def __init__( self.teacher_lin = torch.nn.Linear( in_features=H, out_features=V, bias=False, dtype=dtype, device=device ) - self.fused_jsd = LigerFusedLinearJSD(beta, temperature) + self.fused_jsd = LigerFusedLinearJSD( + jsd_beta=beta, ignore_index=ignore_index, temperature=temperature + ) - def forward(self, student_input, teacher_input): + def forward(self, student_input, teacher_input, label=None): return self.fused_jsd( student_input, self.student_lin.weight, teacher_input, self.teacher_lin.weight, + label, ) diff --git 
a/benchmark/scripts/benchmark_jsd.py b/benchmark/scripts/benchmark_jsd.py index 44e8cc269..272008315 100644 --- a/benchmark/scripts/benchmark_jsd.py +++ b/benchmark/scripts/benchmark_jsd.py @@ -1,5 +1,4 @@ import torch -import torch.nn as nn import triton from utils import ( QUANTILES, @@ -13,24 +12,41 @@ from liger_kernel.transformers.jsd import LigerJSD -class TorchJSD(nn.Module): - def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float): +class TorchJSD(torch.nn.Module): + def __init__( + self, + beta: float = 0.5, + ignore_index: int = -100, + dtype: torch.dtype = torch.float, + ): super(TorchJSD, self).__init__() - self.kl = nn.KLDivLoss(reduction="batchmean", log_target=True) + self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True) self.beta = beta + self.ignore_index = ignore_index self.dtype = dtype def forward( self, - log_q: torch.tensor, # input - log_p: torch.tensor, # target + log_q: torch.Tensor, # input + log_p: torch.Tensor, # target + label=None, ): log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) - m = torch.lerp(torch.exp(log_p), torch.exp(log_q), self.beta) - loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl( - torch.log(m), log_q - ) + m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( + 1 - self.beta + ) * self.kl(torch.log(m), log_q).sum(dim=-1) + + if label is not None: + loss = torch.where(label != self.ignore_index, loss, 0.0) + n_non_ignore = (label != self.ignore_index).sum().item() + if n_non_ignore == 0: + loss = 0.0 + else: + loss = (loss / n_non_ignore).sum() + else: + loss = (loss / log_q.shape[0]).sum() return loss.to(self.dtype) diff --git a/src/liger_kernel/ops/fused_linear_jsd.py b/src/liger_kernel/ops/fused_linear_jsd.py index 9f191c14a..34cb185c1 100644 --- a/src/liger_kernel/ops/fused_linear_jsd.py +++ b/src/liger_kernel/ops/fused_linear_jsd.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch import triton @@ -15,7 +17,10 @@ def fused_linear_jsd_forward( student_weight, teacher_input, teacher_weight, + shift_labels, jsd_beta, + ignore_index, + has_label, temperature, ): device = student_input.device @@ -46,6 +51,11 @@ def fused_linear_jsd_forward( # we use fp32 for loss accumulator loss_1d = torch.zeros((BT, V), dtype=torch.float32, device=device) + if has_label: + n_non_ignore = (shift_labels != ignore_index).sum().item() + else: + n_non_ignore = BT + for chunk_id in range(num_chunks): start_idx = chunk_id * chunk_size end_idx = min((chunk_id + 1) * chunk_size, BT) @@ -81,10 +91,15 @@ def fused_linear_jsd_forward( loss_stride=loss_1d_slice.stride(-2), dX_ptr=student_prob_chunk, dX_stride=student_prob_chunk.stride(-2), + label_ptr=( + shift_labels if has_label else torch.empty(1, device=device) + ), # dummy ptr if no label beta=jsd_beta, - n_rows=BT, # batchmean + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, n_cols=V, BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, ) loss_1d[start_idx:end_idx] = loss_1d_slice # gradients of prob_chunk in place, shape: chunk_size x V @@ -157,12 +172,14 @@ class LigerFusedLinearJSDFunction(torch.autograd.Function): @staticmethod def forward( ctx, - student_input, - student_weight, - teacher_input, - teacher_weight, - jsd_beta=0.5, - temperature=1.0, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + 
shift_labels: Optional[torch.Tensor] = None, + jsd_beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, ): """ Args: @@ -171,18 +188,31 @@ def forward( student_weight (torch.tensor): the last projection layer in student model, with shape (V, H), where V is vocab size teacher_input (torch.tensor): input of the last projection layer in teacher model, with shape (B*T, H), where B is batch size, T is sequence length, H is hidden dimension. teacher_weight (torch.tensor): the last projection layer in teacher model, with shape (V, H), where V is vocab size + shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5` + ignore_index (int): the index to ignore. Default: -100 temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0` Returns: loss (torch.Tensor): generalized JSD """ + has_label = False + if shift_labels is not None: + assert shift_labels.shape == ( + teacher_input.shape[0], + ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + shift_labels = shift_labels.contiguous() + has_label = True + loss, grad_input, grad_weight = fused_linear_jsd_forward( student_input, student_weight, teacher_input, teacher_weight, + shift_labels, jsd_beta, + ignore_index, + has_label, temperature, ) # downcast to dtype and store for backward @@ -198,4 +228,4 @@ def backward(ctx, grad_output): grad_input, grad_weight = fused_linear_jsd_backward( grad_output, grad_input, grad_weight ) - return (grad_input, grad_weight, None, None, None, None) + return (grad_input, grad_weight, None, None, None, None, None, None) diff --git a/src/liger_kernel/ops/jsd.py b/src/liger_kernel/ops/jsd.py index e2a77d1d3..33ec2498c 100644 --- a/src/liger_kernel/ops/jsd.py +++ b/src/liger_kernel/ops/jsd.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch import triton import triton.language as tl @@ -15,10 +17,13 @@ def _jsd_kernel( loss_stride, dX_ptr, dX_stride, + label_ptr, beta, - n_rows, + n_non_ignore, + ignore_index: tl.constexpr, n_cols, BLOCK_SIZE: tl.constexpr, + HAS_LABEL: tl.constexpr, ): # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X) # = sum(P * log P + Q * log Q - 2 * M * log M) / 2 @@ -29,6 +34,15 @@ def _jsd_kernel( dX_ptr += pid * dX_stride Y_ptr += pid * Y_stride loss_ptr += pid * loss_stride + label_ptr += pid + + if HAS_LABEL: + label = tl.load(label_ptr) + if label == ignore_index: + for i in range(0, n_cols, BLOCK_SIZE): + offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(dX_ptr + offsets, 0.0, mask=offsets < n_cols) + return for i in range(0, n_cols, BLOCK_SIZE): offsets = i + tl.arange(0, BLOCK_SIZE) @@ -43,17 +57,17 @@ def _jsd_kernel( loss = beta * P * Y + (1 - beta) * Q * X - M * log_M # reduction == "batchmean" - loss = loss / n_rows + loss = loss / n_non_ignore tl.store(loss_ptr + offsets, loss, mask=mask) - dX = (1 - beta) * Q * (X - log_M) / n_rows + dX = (1 - beta) * Q * (X - log_M) / n_non_ignore tl.store(dX_ptr + offsets, dX, mask=mask) MAX_FUSED_SIZE = 65536 -def jsd_forward(_input, target, beta): +def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label): BT, V = _input.shape n_rows = BT BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) @@ -61,6 +75,11 @@ def jsd_forward(_input, target, beta): loss = torch.zeros(_input.shape, dtype=torch.float32, 
device=_input.device) dX = torch.empty_like(_input) + if has_label: + n_non_ignore = (shift_labels != ignore_index).sum().item() + else: + n_non_ignore = BT + _jsd_kernel[(n_rows,)]( X_ptr=_input, # input in logspace, X = log Q X_stride=_input.stride(-2), @@ -70,10 +89,15 @@ def jsd_forward(_input, target, beta): loss_stride=loss.stride(-2), dX_ptr=dX, dX_stride=dX.stride(-2), + label_ptr=( + shift_labels if has_label else torch.empty(1, device=_input.device) + ), # dummy ptr if no label beta=beta, - n_rows=n_rows, + n_non_ignore=n_non_ignore, + ignore_index=ignore_index, n_cols=V, BLOCK_SIZE=BLOCK_SIZE, + HAS_LABEL=has_label, ) loss = torch.sum(loss) @@ -109,18 +133,32 @@ def forward( ctx, _input: torch.Tensor, target: torch.Tensor, + shift_labels: Optional[torch.Tensor] = None, beta: float = 0.5, + ignore_index: int = -100, ) -> torch.Tensor: """ Args: _input (torch.Tensor): predict values with shape (BT, V) in logspace target (torch.Tensor): ground truth values with shape (BT, V) in logspace + shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. beta (float): coefficient beta of generalized JSD in the open interval (0, 1) + ignore_index (int): the index to ignore. Default: -100 Returns: loss (torch.Tensor): generalized JSD """ - loss, dX = jsd_forward(_input, target, beta) + has_label = False + if shift_labels is not None: + assert shift_labels.shape == ( + _input.shape[0], + ), f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" + shift_labels = shift_labels.contiguous() + has_label = True + + loss, dX = jsd_forward( + _input, target, shift_labels, beta, ignore_index, has_label + ) ctx.save_for_backward(dX) return loss @@ -133,4 +171,6 @@ def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: dX, None, None, + None, + None, ) diff --git a/src/liger_kernel/transformers/fused_linear_jsd.py b/src/liger_kernel/transformers/fused_linear_jsd.py index d9579fe4b..001174cc2 100644 --- a/src/liger_kernel/transformers/fused_linear_jsd.py +++ b/src/liger_kernel/transformers/fused_linear_jsd.py @@ -1,9 +1,11 @@ -import torch.nn as nn +from typing import Optional + +import torch from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction -class LigerFusedLinearJSD(nn.Module): +class LigerFusedLinearJSD(torch.nn.Module): r"""Fusing the last linear layer with generalized JSD Handle the forward and backward pass of the final linear layer via JSD by avoiding @@ -11,6 +13,7 @@ class LigerFusedLinearJSD(nn.Module): Args: jsd_beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5` + ignore_index (int): The index to ignore in the target. Default: `-100` temperature (float): temperature in softmax function to control the output probability distribution. Default: `1.0` Shape: @@ -18,24 +21,54 @@ class LigerFusedLinearJSD(nn.Module): - student_weight: :math:`(V, H)`, where V is vocab size. - teacher_input: :math:`(BT, H')`, where H' is hidden dimension of the teacher model. - teacher_weight: :math:`(V, H')`, where hidden size H and H' can be different. + - shift_labels: :math:`(BT,)` - Output: a scalar. 
Examples: ```python - >>> (B, T, H, V) = (2, 2, 3, 5) + >>> (B, T, H_s, H_t, V) = (2, 2, 3, 5, 10) >>> fused_jsd = LigerFusedLinearJSD(jsd_beta=0.1, temperature=2.0) >>> # generate inputs and weights - >>> student_input = torch.rand(B * T, H, device="cuda", requires_grad=True) - >>> student_lin = torch.nn.Linear(H, V, bias=False, device="cuda") + >>> student_input = torch.rand(B * T, H_s, device="cuda", requires_grad=True) + >>> student_lin = torch.nn.Linear(H_s, V, bias=False, device="cuda") >>> # teacher input doesn't require grad, hidden_dim can be different from student's - >>> teacher_input = torch.rand(B * T, H * 2, device="cuda") - >>> teacher_lin = torch.nn.Linear(H * 2, V, bias=False, device="cuda") + >>> teacher_input = torch.rand(B * T, H_t, device="cuda") + >>> teacher_lin = torch.nn.Linear(H_t, V, bias=False, device="cuda") >>> output = fused_jsd(student_input, student_lin.weight, teacher_input, teacher_lin.weight) >>> output.backward() + >>> + >>> # Example with labels for supervised fine-tuning (SFT) context: + >>> + >>> # Assume hidden_states, lm_heads and corresponding labels are given + >>> student_lm_head = torch.nn.Linear(H_s, V, bias=False) + >>> student_hidden_states = torch.randn(B, T, H_s, requires_grad=True) + >>> teacher_lm_head = torch.nn.Linear(H_t, V, bias=False) + >>> teacher_hidden_states = torch.randn(B, T, H_t) + >>> labels = torch.randint(0, V, (B, T), dtype=torch.long) + >>> + >>> # Shift so that tokens < n predict n + >>> shift_student_hidden_states = student_hidden_states[..., :-1, :].contiguous() + >>> shift_teacher_hidden_states = teacher_hidden_states[..., :-1, :].contiguous() + >>> shift_labels = labels[..., 1:].contiguous() + >>> + >>> # Flatten tokens + >>> shift_student_hidden_states = shift_student_hidden_states.view(-1, H_s) + >>> shift_teacher_hidden_states = shift_teacher_hidden_states.view(-1, H_t) + >>> shift_labels = shift_labels.view(-1) + >>> + >>> # Calculate loss + >>> loss_fct = LigerFusedLinearJSD(jsd_beta=0.1) + >>> loss = loss_fct( + >>> shift_student_hidden_states, + >>> student_lm_head.weight, + >>> shift_teacher_hidden_states, + >>> teacher_lm_head.weight, + >>> shift_labels + >>> ) ``` """ - def __init__(self, jsd_beta=0.5, temperature=1.0): + def __init__(self, jsd_beta=0.5, ignore_index=-100, temperature=1.0): super().__init__() assert ( jsd_beta > 0 and jsd_beta < 1 @@ -43,19 +76,23 @@ def __init__(self, jsd_beta=0.5, temperature=1.0): assert temperature != 0, "temperature cannot be 0." self.jsd_beta = jsd_beta self.temperature = temperature + self.ignore_index = ignore_index def forward( self, - student_input, - student_weight, - teacher_input, - teacher_weight, + student_input: torch.Tensor, + student_weight: torch.Tensor, + teacher_input: torch.Tensor, + teacher_weight: torch.Tensor, + shift_labels: Optional[torch.LongTensor] = None, ): return LigerFusedLinearJSDFunction.apply( student_input, student_weight, teacher_input, teacher_weight, + shift_labels, self.jsd_beta, + self.ignore_index, self.temperature, ) diff --git a/src/liger_kernel/transformers/jsd.py b/src/liger_kernel/transformers/jsd.py index 2040c80a6..e218ca84b 100644 --- a/src/liger_kernel/transformers/jsd.py +++ b/src/liger_kernel/transformers/jsd.py @@ -1,9 +1,11 @@ -import torch.nn as nn +from typing import Optional + +import torch from liger_kernel.ops.jsd import LigerJSDFunction -class LigerJSD(nn.Module): +class LigerJSD(torch.nn.Module): r"""The generalized Jensen-Shannon Divergence. ..
math:: JSD(\beta)(P || Q) @@ -17,28 +19,57 @@ class LigerJSD(nn.Module): Args: beta (float): coefficient beta of generalized JSD in the open interval (0, 1). Default: `0.5` + ignore_index (int): The index to ignore in the target. Default: `-100` Shape: - - Input: :math:`(*)`, where :math:`*` means any number of dimensions. - - Target: :math:`(*)`, same shape as the input. + - Input: :math:`(BT, V)`, where B is batch size, T is sequence length, V is vocab size. + - Target: :math:`(BT, V)`, same shape as the input. + - shift_labels (Optional): :math:`(BT,)` - Output: a scalar. Examples: ```python + >>> (B, T, V) = (2, 2, 5) >>> jsd = LigerJSD(beta=0.1) >>> # input should be a distribution in the log space - >>> input = torch.randn(3, 5, requires_grad=True).log_softmax(dim=-1) - >>> target = torch.randn(3, 5, requires_grad=True).log_softmax(dim=-1) + >>> input = torch.randn(B * T, V, requires_grad=True).log_softmax(dim=-1) + >>> target = torch.randn(B * T, V).log_softmax(dim=-1) >>> output = jsd(input, target) + >>> + >>> # Example with labels for supervised fine-tuning (SFT) context + >>> # Assume logits and corresponding labels are given + >>> student_logits = torch.randn(B, T, V, requires_grad=True).log_softmax(dim=-1) + >>> teacher_logits = torch.randn(B, T, V).log_softmax(dim=-1) + >>> labels = torch.randint(0, V, (B, T), dtype=torch.long) + >>> # Shift so that tokens < n predict n + >>> shift_student_logits = student_logits[..., :-1, :].contiguous() + >>> shift_teacher_logits = teacher_logits[..., :-1, :].contiguous() + >>> shift_labels = labels[..., 1:].contiguous() + >>> # Flatten tokens + >>> shift_student_logits = shift_student_logits.view(-1, V) + >>> shift_teacher_logits = shift_teacher_logits.view(-1, V) + >>> shift_labels = shift_labels.view(-1) + >>> # Calculate loss + >>> loss_fct = LigerJSD(beta=0.1) + >>> loss = loss_fct(shift_student_logits, shift_teacher_logits, shift_labels) ``` """ - def __init__(self, beta=0.5): + def __init__(self, beta: float = 0.5, ignore_index: int = -100): super().__init__() assert ( beta > 0 and beta < 1 ), f"beta must be greater than 0 and less than 1. 
Got: {beta}" self.beta = beta + self.ignore_index = ignore_index - def forward(self, log_q, log_p): - return LigerJSDFunction.apply(log_q, log_p, self.beta) + def forward( + self, + log_q: torch.Tensor, + log_p: torch.Tensor, + shift_labels: Optional[torch.LongTensor] = None, + ): + return LigerJSDFunction.apply( + log_q, log_p, shift_labels, self.beta, self.ignore_index + ) diff --git a/test/transformers/test_fused_linear_jsd.py b/test/transformers/test_fused_linear_jsd.py index e5c3b1035..321f45ab6 100644 --- a/test/transformers/test_fused_linear_jsd.py +++ b/test/transformers/test_fused_linear_jsd.py @@ -26,8 +26,9 @@ def __init__( V: int, dtype: torch.dtype, device: torch.device, - temperature: float = 1.0, beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, ): super().__init__() self.student_lin = torch.nn.Linear( @@ -36,16 +37,16 @@ def __init__( self.teacher_lin = torch.nn.Linear( in_features=H, out_features=V, bias=False, dtype=dtype, device=device ) - self.jsd = TorchJSD(beta, dtype=dtype) + self.jsd = TorchJSD(beta=beta, ignore_index=ignore_index, dtype=dtype) self.temperature = temperature - def forward(self, student_input, teacher_input): + def forward(self, student_input, teacher_input, label=None): student_logits = self.student_lin(student_input) teacher_logits = self.teacher_lin(teacher_input) student_prob = torch.log_softmax(student_logits / self.temperature, dim=-1) teacher_prob = torch.log_softmax(teacher_logits / self.temperature, dim=-1) - return self.jsd(student_prob, teacher_prob) + return self.jsd(student_prob, teacher_prob, label) class LigerLMHeadJSD(torch.nn.Module): @@ -55,8 +56,9 @@ def __init__( V: int, dtype: torch.dtype, device: torch.device, - temperature: float = 1.0, beta: float = 0.5, + ignore_index: int = -100, + temperature: float = 1.0, ): super().__init__() self.student_lin = torch.nn.Linear( @@ -65,14 +67,17 @@ def __init__( self.teacher_lin = torch.nn.Linear( in_features=H, out_features=V, bias=False, dtype=dtype, device=device ) - self.fused_jsd = LigerFusedLinearJSD(beta, temperature) + self.fused_jsd = LigerFusedLinearJSD( + jsd_beta=beta, ignore_index=ignore_index, temperature=temperature + ) - def forward(self, student_input, teacher_input): + def forward(self, student_input, teacher_input, label=None): return self.fused_jsd( student_input, self.student_lin.weight, teacher_input, self.teacher_lin.weight, + label, ) @@ -139,8 +144,98 @@ def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar - output1 = torch_lm_head_jsd(_input1, teacher_input) - output2 = liger_lm_head_jsd(_input2, teacher_input) + with torch.autograd.detect_anomaly(): + output1 = torch_lm_head_jsd(_input1, teacher_input) + output2 = liger_lm_head_jsd(_input2, teacher_input) + + assert torch.allclose(output1, output2, atol=atol, rtol=rtol) + + output1.backward() + output2.backward() + + assert torch.allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol) + + assert torch.allclose( + torch_lm_head_jsd.student_lin.weight.grad, + liger_lm_head_jsd.student_lin.weight.grad, + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize( + "B, T, H, V", + [ + (2, 4, 2048, 3200), + (2, 2048, 4096, 32000), # llama2, mistral + # Comment out to speed up testing + # (4, 2048, 4096, 128256), # llama3 8B + # (4, 1024, 8192, 128256), # llama3 70B + (4, 423, 8192, 32000), # random shape + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + (1.0, 
torch.bfloat16, 5e-3, 5e-2), + (1.0, torch.float32, 1e-5, 5e-4), + ], +) +@pytest.mark.parametrize( + "temperature, beta, ignore_index", + [ + (1.0, 0.5, 2), + (2.0, 0.1, 42), + ], +) +def test_correctness_with_ignore_index( + B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol +): + device = "cuda" + torch_lm_head_jsd = TorchLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + beta=beta, + ignore_index=ignore_index, + ).to(device) + liger_lm_head_jsd = LigerLMHeadJSD( + H=H, + V=V, + dtype=dtype, + device=device, + temperature=temperature, + beta=beta, + ignore_index=ignore_index, + ).to(device) + + # init the linear in all FusedLinearJSDs with the same weights + torch_lm_head_jsd.student_lin.weight.data = ( + liger_lm_head_jsd.student_lin.weight.data + ) = torch.rand(V, H // 2, device=device, dtype=dtype) + torch_lm_head_jsd.teacher_lin.weight.data = ( + liger_lm_head_jsd.teacher_lin.weight.data + ) = torch.rand(V, H, device=device, dtype=dtype) + + _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar + _input1 = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + label[indices_to_assign] = ignore_index + + output1 = torch_lm_head_jsd(_input1, teacher_input, label) + output2 = liger_lm_head_jsd(_input2, teacher_input, label) assert torch.allclose(output1, output2, atol=atol, rtol=rtol) @@ -175,9 +270,11 @@ def test_correctness(B, T, H, V, scalar, dtype, beta, temperature, atol, rtol): (0.5, torch.float32, 1e-5, 5e-4), ], ) -@pytest.mark.parametrize("temperature, beta", [(1.0, 0.5), (2.0, 0.1)]) +@pytest.mark.parametrize( + "temperature, beta, ignore_index", [(1.0, 0.5, -100), (2.0, 0.1, 42)] +) def test_correctness_functional( - B, T, H, V, scalar, dtype, beta, temperature, atol, rtol + B, T, H, V, scalar, dtype, beta, ignore_index, temperature, atol, rtol ): device = "cuda" @@ -192,11 +289,36 @@ def test_correctness_functional( _input2 = _tensor.detach().clone().requires_grad_(True) teacher_input = torch.rand(B * T, H, device=device, dtype=dtype) * scalar + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + label[indices_to_assign] = ignore_index + output1 = liger_fused_linear_jsd( - _input1, _weight1, teacher_input, teacher_weight, beta, temperature + _input1, + _weight1, + teacher_input, + teacher_weight, + label, + beta, + ignore_index, + temperature, ) output2 = LigerFusedLinearJSDFunction.apply( - _input2, _weight2, teacher_input, teacher_weight, beta, temperature + _input2, + _weight2, + teacher_input, + teacher_weight, + label, + beta, + ignore_index, + temperature, ) assert torch.allclose(output1, output2, atol=atol, rtol=rtol) diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py index a30582041..564b85cfc 100644 --- a/test/transformers/test_jsd.py +++ 
b/test/transformers/test_jsd.py @@ -1,4 +1,5 @@ from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16 +from typing import Optional import pytest import torch @@ -11,23 +12,40 @@ class JSD(torch.nn.Module): - def __init__(self, beta: float = 0.5, dtype: torch.dtype = torch.float): + def __init__( + self, + beta: float = 0.5, + ignore_index: int = -100, + dtype: torch.dtype = torch.float, + ): super(JSD, self).__init__() - self.kl = KLDivLoss(reduction="batchmean", log_target=True) + self.kl = KLDivLoss(reduction="none", log_target=True) self.beta = beta + self.ignore_index = ignore_index self.dtype = dtype def forward( self, - log_q: torch.tensor, # input - log_p: torch.tensor, # target + log_q: torch.Tensor, # input + log_p: torch.Tensor, # target + label: Optional[torch.Tensor] = None, ): log_p, log_q = log_p.to(torch.float), log_q.to(torch.float) log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1)) m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta) - loss = self.beta * self.kl(torch.log(m), log_p) + (1 - self.beta) * self.kl( - torch.log(m), log_q - ) + loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + ( + 1 - self.beta + ) * self.kl(torch.log(m), log_q).sum(dim=-1) + + if label is not None: + loss = torch.where(label != self.ignore_index, loss, 0.0) + n_non_ignore = (label != self.ignore_index).sum().item() + if n_non_ignore == 0: + loss = 0.0 + else: + loss = (loss / n_non_ignore).sum() + else: + loss = (loss / log_q.shape[0]).sum() return loss.to(self.dtype) @@ -148,8 +166,51 @@ def _test_correctness_with_beta_once( assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_ignore_index_once( + target_jsd, + ignore_index, + B, + T, + V, + dtype, + atol, + rtol, + device="cuda", +): + torch_jsd = JSD(ignore_index=ignore_index, dtype=dtype) + + input = torch.randn( + B * T, V, device=device, dtype=dtype, requires_grad=True + ).log_softmax(dim=-1) + + x1 = input.detach().clone().requires_grad_(True) + x2 = input.detach().clone().requires_grad_(True) + + with torch.no_grad(): + target = torch.randn(B * T, V, dtype=dtype, device=device).log_softmax(dim=-1) + + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + label[indices_to_assign] = ignore_index + + output = torch_jsd(x1, target, label) + output2 = target_jsd(x2, target, label) + assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert_verbose_allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) + + def _test_correctness_functional( - B, T, V, beta, is_last_layer, dtype, atol, rtol, device="cuda" + B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol, device="cuda" ): input = torch.randn( B * T, V, device=device, dtype=dtype, requires_grad=True @@ -161,8 +222,19 @@ def _test_correctness_functional( with torch.no_grad(): target = torch.randn(B * T, V, dtype=dtype, device=device).log_softmax(dim=-1) - output = LigerJSDFunction.apply(x1, target, beta) - output2 = liger_jsd(x2, target, beta) + label = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T 
// 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + label[indices_to_assign] = ignore_index + + output = LigerJSDFunction.apply(x1, target, label, beta, ignore_index) + output2 = liger_jsd(x2, target, label, beta, ignore_index) assert torch.allclose(output, output2, atol=atol, rtol=rtol) if ( not is_last_layer @@ -197,14 +269,28 @@ def test_correctness_with_beta(B, T, V, beta, dtype, atol, rtol): _test_correctness_with_beta_once(liger_jsd, beta, B, T, V, dtype, atol, rtol) +@pytest.mark.parametrize(*_SHAPE_PARAMS) +@pytest.mark.parametrize(*_DTYPE_PARAMS) +@pytest.mark.parametrize("ignore_index", [2, 42]) +def test_correctness_with_ignore_index(B, T, V, ignore_index, dtype, atol, rtol): + liger_jsd = LigerJSD(ignore_index=ignore_index) + _test_correctness_with_ignore_index_once( + liger_jsd, ignore_index, B, T, V, dtype, atol, rtol + ) + + @pytest.mark.parametrize(*_SHAPE_PARAMS) @pytest.mark.parametrize(*_DTYPE_PARAMS) @pytest.mark.parametrize( - "beta, is_last_layer", + "beta, ignore_index, is_last_layer", [ - (0.5, False), - (0.1, True), + (0.5, 2, False), + (0.1, 42, True), ], ) -def test_correctness_functional(B, T, V, beta, is_last_layer, dtype, atol, rtol): - _test_correctness_functional(B, T, V, beta, is_last_layer, dtype, atol, rtol) +def test_correctness_functional( + B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol +): + _test_correctness_functional( + B, T, V, beta, ignore_index, is_last_layer, dtype, atol, rtol + )
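For orientation, a short usage sketch of the `ignore_index` path that these tests exercise (the sizes, label masking, and variable names below are illustrative assumptions; the `LigerJSD` constructor and forward signature follow the patch above):

```python
import torch
from liger_kernel.transformers import LigerJSD

# Illustrative sizes: B sequences of T tokens over a V-token vocabulary
B, T, V = 2, 8, 32
ignore_index = -100

student_logits = torch.randn(B, T, V, device="cuda", requires_grad=True)
teacher_logits = torch.randn(B, T, V, device="cuda")
labels = torch.randint(0, V, (B, T), device="cuda", dtype=torch.long)
labels[:, -2:] = ignore_index  # e.g. mask out padding positions

jsd = LigerJSD(beta=0.5, ignore_index=ignore_index)
loss = jsd(
    torch.log_softmax(student_logits, dim=-1).view(-1, V),  # log_q: student log-probs, shape (BT, V)
    torch.log_softmax(teacher_logits, dim=-1).view(-1, V),  # log_p: teacher log-probs, shape (BT, V)
    labels.view(-1),                                        # shift_labels, shape (BT,)
)
loss.backward()
```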