diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index b0092d5ef..d9a981912 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -27,11 +27,13 @@ def liger_cross_entropy_kernel(
     X_stride,
     Y_ptr,
     Y_stride,
+    weight_ptr,
     loss_ptr,
     z_loss_ptr,
     loss_stride,
     n_cols,
     n_non_ignore,
+    weight_sum,
     ignore_index,
     lse_square_scale: tl.constexpr,
     label_smoothing: tl.constexpr,
@@ -39,6 +41,7 @@ def liger_cross_entropy_kernel(
     softcap,
     RETURN_Z_LOSS: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
 ):
     """
@@ -50,11 +53,13 @@ def liger_cross_entropy_kernel(
     X_stride (int): The stride of the input tensor.
     Y_ptr: Pointer to target tensor.
     Y_stride (int): The stride of the target tensor.
+    weight_ptr: Pointer to the class-weight tensor.
     loss_ptr: Pointer to tensor to store the loss.
     z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
     loss_stride (int): The stride of the loss tensor.
     n_cols (int): The number of columns in the input tensor.
-    n_non_ignore (int): The number of non-ignored elements in the batch.
+    n_non_ignore (float): The number of non-ignored elements in the batch, or the sum of the non-ignored targets' weights when a weight tensor is given.
+    weight_sum (float): The sum of the weight tensor.
     ignore_index (int): The index to ignore in the target.
     label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
     lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
@@ -62,6 +68,7 @@ def liger_cross_entropy_kernel(
     reduction (str): The string for the reduction to apply
     softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
    BLOCK_SIZE (int): The block size for Triton operations.
+    HAS_WEIGHT (bool): The boolean value to determine whether a per-class weight tensor is provided.
     HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
""" @@ -86,6 +93,9 @@ def liger_cross_entropy_kernel( loss_ptr += program_id * loss_stride z_loss_ptr += program_id * loss_stride + if HAS_WEIGHT: + weight_y = tl.load(weight_ptr + y).cast(tl.float32) + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 @@ -116,7 +126,15 @@ def liger_cross_entropy_kernel( block_max = tl.max(X_block) if label_smoothing > 0: # scale X beforehand to avoid overflow - scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum( + tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0) + ) + else: + scaled_x_sum += tl.sum( + tl.where(X_offsets < n_cols, -eps * X_block, 0.0) + ) m_new = tl.maximum(m, block_max) d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new @@ -163,6 +181,8 @@ def liger_cross_entropy_kernel( # reduction scale if reduction == "mean": X_block = X_block / (n_non_ignore) + if HAS_WEIGHT: + X_block = X_block * weight_y # chain rule # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: @@ -182,6 +202,8 @@ def liger_cross_entropy_kernel( # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 # So we can safely calculate log (softmax(X_y)) without overflow loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) @@ -192,7 +214,10 @@ def liger_cross_entropy_kernel( # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 if label_smoothing > 0: - smooth_loss = scaled_x_sum + label_smoothing * lse + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse loss = loss * (1 - label_smoothing) + smooth_loss # An auxiliary loss, z_loss @@ -203,7 +228,6 @@ def liger_cross_entropy_kernel( if reduction == "mean": z_loss = z_loss / n_non_ignore loss = loss / n_non_ignore - tl.store(loss_ptr, loss) if RETURN_Z_LOSS == _TRUE: tl.store(z_loss_ptr, z_loss) @@ -224,6 +248,7 @@ def liger_cross_entropy_kernel( def cross_entropy_forward( _input, target, + weight, ignore_index, lse_square_scale, label_smoothing, @@ -253,7 +278,23 @@ def cross_entropy_forward( else: z_loss_1d = loss_1d # dummy ptr when return_z_loss == False - n_non_ignore = (target != ignore_index).sum().item() + target_mask = target != ignore_index + n_non_ignore = target_mask.sum().item() + weight_sum = weight.sum().item() + if weight is not None: + assert ( + weight.shape[0] == V + ), f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert torch.is_floating_point( + weight + ), f"If given, weight has to be a Tensor of floating point dtype. 
+        n_non_ignore = (
+            torch.gather(weight, dim=0, index=target.masked_select(target_mask))
+            .sum()
+            .item()
+        )
+        if weight.stride(-1) != 1:
+            weight = weight.contiguous()

     # ensure _input and target are contiguous in the last dimension
     if _input.stride(-1) != 1:
@@ -267,18 +308,21 @@ def cross_entropy_forward(
         X_stride=_input.stride(-2),
         Y_ptr=target,
         Y_stride=target.stride(-1),  # always 1
+        weight_ptr=weight if weight is not None else _input,  # dummy if None
         loss_ptr=loss_1d,
         z_loss_ptr=z_loss_1d,
         loss_stride=loss_1d.stride(-1),  # always 1
         n_cols=V,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
+        weight_sum=weight_sum,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
         softcap=softcap if softcap is not None else 0.0,
         RETURN_Z_LOSS=return_z_loss,
         BLOCK_SIZE=BLOCK_SIZE,
+        HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
@@ -330,6 +374,7 @@ def forward(
         ctx,
         _input: torch.Tensor,
         target: torch.Tensor,
+        weight: Optional[torch.FloatTensor],
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -344,6 +389,7 @@ def forward(
         ctx : The context object.
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
+        weight (Tensor, optional): A manual rescaling weight given to each class. If given, it has to be a Tensor of size V and of floating point dtype.
         ignore_index (int): The index to ignore in the target.
         lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -357,6 +403,7 @@ def forward(
         loss, z_loss, _input = cross_entropy_forward(
             _input,
             target,
+            weight,
             ignore_index,
             lse_square_scale,
             label_smoothing,
@@ -398,4 +445,5 @@ def backward(ctx, grad_output, grad_ouput2):
             None,
             None,
             None,
+            None,
         )
diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index 191a2b3d2..a3d0406f1 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -83,17 +83,20 @@ def fused_linear_cross_entropy_forward(
            X_stride=logits_chunk.stride(-2),
            Y_ptr=target_chunk,
            Y_stride=target_chunk.stride(-1),  # always 1
+            weight_ptr=_input,  # dummy ptr, not used
            loss_ptr=loss_1d_slice,
            z_loss_ptr=loss_1d_slice,  # dummy ptr, not used
            loss_stride=loss_1d_slice.stride(-1),  # always 1
            n_cols=V,
            n_non_ignore=n_non_ignore,
+            weight_sum=n_non_ignore,  # not used since HAS_WEIGHT=False
            ignore_index=ignore_index,
            lse_square_scale=lse_square_scale,
            label_smoothing=label_smoothing,
            reduction=reduction,
            softcap=softcap if softcap is not None else 0.0,
            RETURN_Z_LOSS=0,  # False
+            HAS_WEIGHT=False,
            HAS_SOFTCAPPING=True if softcap is not None else False,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=32 if not is_hip() else 16,
diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py
index 7bd27edd6..f3e51808c 100644
--- a/src/liger_kernel/transformers/cross_entropy.py
+++ b/src/liger_kernel/transformers/cross_entropy.py
@@ -8,6 +8,7 @@ class LigerCrossEntropyLoss(torch.nn.Module):
     def __init__(
         self,
+        weight: Optional[torch.FloatTensor] = None,
         ignore_index: int = -100,
         lse_square_scale: float = 0.0,
         label_smoothing: float = 0.0,
@@ -30,6 +31,7 @@ def __init__(
         assert (
             softcap is None or softcap > 0
         ), f"softcap must greater than 0.0 or None. Got: {softcap}"
+        self.weight = weight
         self.ignore_index = ignore_index
         self.lse_square_scale = lse_square_scale
         self.label_smoothing = label_smoothing
@@ -41,6 +43,7 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor):
         loss, z_loss = LigerCrossEntropyFunction.apply(
             _input,
             target,
+            self.weight,
             self.ignore_index,
             self.lse_square_scale,
             self.label_smoothing,
diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py
index 45ad6159a..5d6086caa 100644
--- a/src/liger_kernel/transformers/functional.py
+++ b/src/liger_kernel/transformers/functional.py
@@ -34,6 +34,7 @@ def liger_cross_entropy(
     loss, z_loss = LigerCrossEntropyFunction.apply(
         input,
         target,
+        weight,
         ignore_index,
         lse_square_scale,
         label_smoothing,
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py
index 5bb59d718..8c402254e 100644
--- a/test/transformers/test_cross_entropy.py
+++ b/test/transformers/test_cross_entropy.py
@@ -21,6 +21,7 @@ class CrossEntropyWithZLoss(torch.nn.Module):
     def __init__(
         self,
+        weight=None,
         lse_square_scale=0.0,
         reduction="mean",
         ignore_index=-100,
@@ -29,6 +30,7 @@ def __init__(
         dtype=torch.float32,
     ):
         super().__init__()
+        self.weight = weight
         self.lse_square_scale = lse_square_scale
         self.reduction = reduction
         self.ignore_index = ignore_index
@@ -39,10 +41,23 @@ def __init__(
     def forward(self, logits, targets):
         # Loss calculations are all in float32
         logits = logits.to(torch.float32)
+        HAS_WEIGHT = True if self.weight is not None else False
+
+        target_mask = targets != self.ignore_index
+        if HAS_WEIGHT:
+            self.weight = self.weight.to(torch.float32)
+            selected_weight = torch.where(
+                target_mask,
+                torch.gather(self.weight, dim=-1, index=targets * target_mask),
+                0.0,
+            )
+            sum_of_non_ignore_weight = selected_weight.sum().item()
+
         # Standard cross entropy loss
         ce_loss = F.cross_entropy(
             logits,
             targets,
+            weight=self.weight,
             reduction=self.reduction,
             label_smoothing=self.label_smoothing,
             ignore_index=self.ignore_index,
@@ -55,9 +70,14 @@ def forward(self, logits, targets):
         z_loss = torch.where(
             targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0
         )
-        z_loss = z_loss.to(logits.dtype)
+        if HAS_WEIGHT:
+            z_loss = z_loss * selected_weight
+
         if self.reduction == "mean":
-            z_loss = z_loss.sum() / (targets != self.ignore_index).sum()
+            if HAS_WEIGHT:
+                z_loss = z_loss.sum() / sum_of_non_ignore_weight
+            else:
+                z_loss = z_loss.sum() / (targets != self.ignore_index).sum()
         elif self.reduction == "sum":
             z_loss = z_loss.sum()
         else:
@@ -306,6 +330,80 @@ def _test_correctness_with_z_loss_with_other_params_once(
     assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


+def _test_correctness_with_weight_once(
+    target_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol
+):
+    torch.manual_seed(0)
+    torch_ce = CrossEntropyLoss(weight=weight, reduction=reduction)
+
+    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
+    _input = _tensor.detach().clone().requires_grad_(True)
+    _input2 = _tensor.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    output = torch_ce(_input, target)
+    output2 = target_ce(_input2, target)
+    assert torch.allclose(output, output2, atol=atol, rtol=rtol)
+
+    output.backward(gradient=torch.ones_like(output))
+    output2.backward(gradient=torch.ones_like(output))
+    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
+
+
+def _test_correctness_with_weight_with_other_params_once(
+    target_ce,
+    B,
+    T,
+    V,
+    reduction,
+    weight,
+    lse_square_scale,
+    ignore_index,
+    label_smoothing,
+    softcap,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+):
+    torch.manual_seed(0)
+    torch_ce = CrossEntropyWithZLoss(
+        weight=weight,
+        lse_square_scale=lse_square_scale,
+        ignore_index=ignore_index,
+        reduction=reduction,
+        label_smoothing=label_smoothing,
+        dtype=dtype,
+    )
+
+    _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar
+    # upcasting to match liger's casting strategy
+    _input = _tensor.detach().clone().requires_grad_(True)
+    _input2 = _tensor.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)
+
+    # Assign some random number of elements as ignore_index
+    num_elements_to_assign = torch.randint(
+        1, B * T // 2, (1,)
+    ).item()  # Random number of elements to set to ignore_index
+    indices_to_assign = torch.randperm(B * T)[
+        :num_elements_to_assign
+    ]  # Randomly select indices
+    target[indices_to_assign] = ignore_index
+
+    output = torch_ce(
+        softcap * torch.tanh(_input.to(torch.float32) / softcap), target
+    ).to(dtype)
+    output2 = target_ce(_input2, target)
+    assert_verbose_allclose(output, output2, atol=atol, rtol=rtol)
+
+    output.backward(gradient=torch.ones_like(output))
+    output2.backward(gradient=torch.ones_like(output))
+    assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
+
+
 def _test_correctness_not_last_layer_once(
     target_ce, B, T, V, reduction, scalar, dtype, atol, rtol
 ):
@@ -350,6 +448,7 @@ def _test_correctness_functional(
     y1, y1_z = liger_cross_entropy(
         x1,
         target,
+        None,
         ignore_index=0,
         lse_square_scale=1e-4,
         label_smoothing=0.1,
@@ -358,7 +457,7 @@ def _test_correctness_functional(
         return_z_loss=True,
     )
     y2, y2_z = LigerCrossEntropyFunction.apply(
-        x2, target, 0, 1e-4, 0.1, "mean", 30.0, True
+        x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True
     )

     assert torch.allclose(y1, y2, atol=atol, rtol=rtol)
@@ -716,6 +815,110 @@ def test_correctness_with_z_loss_with_other_params_once(
         (1.0, torch.float32, 1e-8, 1e-6),
     ],
 )
+def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, rtol):
+    weight = torch.rand(V, device=device, dtype=dtype)
+    test_ce = LigerCrossEntropyLoss(weight=weight, reduction=reduction)
+    _test_correctness_with_weight_once(
+        test_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol
+    )
+
+
+@pytest.mark.parametrize(
+    "B, T, V",
+    [
+        (2, 4096, 3200),  # llama2, mistral
+        # # weird shapes
+        (3, 423, 3200),
+    ],
+)
+@pytest.mark.parametrize("reduction", ["sum", "mean", "none"])
+@pytest.mark.parametrize(
+    "ignore_index, lse_square_scale, label_smoothing, softcap",
+    [
+        (-100, 1e-4, 0.1, 30.0),
+        (42, 1e-5, 0.2, 40.0),
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        pytest.param(
+            1.0,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        (1.0, torch.float32, 1e-8, 1e-6),
+    ],
+)
+def test_correctness_with_weight_with_other_params_once(
+    B,
+    T,
+    V,
+    reduction,
+    lse_square_scale,
+    ignore_index,
+    label_smoothing,
+    softcap,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+):
+    weight = torch.rand(V, device=device, dtype=torch.float32)  # match softcap casting
+    test_ce = LigerCrossEntropyLoss(
+        weight=weight,
+        lse_square_scale=lse_square_scale,
+        reduction=reduction,
+        ignore_index=ignore_index,
+        label_smoothing=label_smoothing,
+        softcap=softcap,
+    )
+    _test_correctness_with_weight_with_other_params_once(
+        test_ce,
+        B,
+        T,
+        V,
+        reduction,
+        weight,
+        lse_square_scale,
+        ignore_index,
+        label_smoothing,
+        softcap,
+        scalar,
+        dtype,
+        atol,
+        rtol,
+    )
+
+
+@pytest.mark.parametrize(
+    "B, T, V",
+    [
+        (2, 4096, 32000),  # llama2, mistral
+        # # weird shapes
+        (3, 423, 32000),
+    ],
+)
+@pytest.mark.parametrize("reduction", ["sum", "mean"])
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        pytest.param(
+            1.0,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        (1.0, torch.float32, 1e-8, 1e-6),
+    ],
+)
 def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol):
     liger_ce = LigerCrossEntropyLoss(reduction=reduction)
     _test_correctness_not_last_layer_once(
@@ -751,17 +954,20 @@ def test_float32_internal():
         X_stride=X_bf16.stride(-2),
         Y_ptr=Y,
         Y_stride=Y.stride(-1),
+        weight_ptr=X_bf16,  # dummy ptr, not used
         z_loss_ptr=loss_bf16,  # dummy ptr, not used
         loss_ptr=loss_bf16,
         loss_stride=loss_bf16.stride(-1),
         n_cols=n_cols,
         n_non_ignore=n_non_ignore,
+        weight_sum=n_non_ignore,  # not used
         ignore_index=ignore_index,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=0,  # False
+        HAS_WEIGHT=False,
         HAS_SOFTCAPPING=False,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=32 if not is_hip() else 16,
@@ -775,17 +981,20 @@ def test_float32_internal():
         X_stride=X_fp32.stride(-2),
         Y_ptr=Y,
         Y_stride=Y.stride(-1),
+        weight_ptr=X_fp32,  # dummy ptr, not used
         loss_ptr=loss_fp32,
         z_loss_ptr=loss_fp32,  # dummy ptr, not used
         loss_stride=loss_fp32.stride(-1),
         n_cols=n_cols,
         n_non_ignore=n_non_ignore,
+        weight_sum=n_non_ignore,  # not used
         ignore_index=ignore_index,
         lse_square_scale=lse_square_scale,
         label_smoothing=label_smoothing,
         reduction=reduction,
         softcap=softcap,
         RETURN_Z_LOSS=0,  # False
+        HAS_WEIGHT=False,
         HAS_SOFTCAPPING=False,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=32 if not is_hip() else 16,
diff --git a/weight_ce.py b/weight_ce.py
new file mode 100644
index 000000000..3431abe5a
--- /dev/null
+++ b/weight_ce.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+
+# Example data: 3 classes
+logits = torch.tensor(
+    [
+        [2.0, 0.5, 0.1],  # Prediction logits for sample 1
+        [0.1, 1.5, 2.1],  # Prediction logits for sample 2
+        [1.0, 2.0, 0.1],  # Prediction logits for sample 3
+    ]
+)
+
+targets = torch.tensor([0, 2, 1])  # Ground truth labels
+
+# Define CrossEntropyLoss without weights
+criterion_no_weight = nn.CrossEntropyLoss(reduction="none", label_smoothing=0.1)
+
+# Define CrossEntropyLoss with weights
+weights = torch.tensor([0.7, 1.0, 1.5])  # Assign different weights to each class
+criterion_with_weight = nn.CrossEntropyLoss(
+    weight=weights, reduction="none", label_smoothing=0.1
+)
+
+# Compute loss without weights
+loss_no_weight = criterion_no_weight(logits, targets)
+
+# Compute loss with weights
+loss_with_weight = criterion_with_weight(logits, targets)
+
+selected_weight = torch.gather(weights, dim=0, index=targets)
+print(f"{selected_weight=}")
+print("Loss without weights:", loss_no_weight)
+print("Loss with weights:", loss_with_weight)
+print("====================================================")
+# Define CrossEntropyLoss without weights
+criterion_no_weight = nn.CrossEntropyLoss(reduction="none")
+
+# Define CrossEntropyLoss with weights
+weights = torch.tensor([0.7, 1.0, 1.5])  # Assign different weights to each class
+criterion_with_weight = nn.CrossEntropyLoss(weight=weights, reduction="none")
+
+# Compute loss without weights
+loss_no_weight = criterion_no_weight(logits, targets)
+
+# Compute loss with weights
+loss_with_weight = criterion_with_weight(logits, targets)
+print(f"{selected_weight=}")
+
+print("Loss without weights:", loss_no_weight)
+print("Loss with weights:", loss_with_weight)
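
The reduction="mean" denominator that cross_entropy_forward now feeds the kernel (the sum of the non-ignored targets' weights rather than the plain row count) follows torch.nn.functional.cross_entropy's documented behavior. A minimal sanity-check sketch of that equivalence, written as a standalone script outside this patch, with illustrative sizes and variable names:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
BT, V, ignore_index = 8, 5, -100  # illustrative sizes, not taken from the patch

logits = torch.randn(BT, V)
target = torch.randint(0, V, (BT,))
target[0] = ignore_index  # one ignored row
weight = torch.rand(V)

# Reference: PyTorch's weighted cross entropy with the default mean reduction.
ref = F.cross_entropy(logits, target, weight=weight, ignore_index=ignore_index)

# Manual reduction: per-row losses already carry weight[target]; ignored rows are 0.
# The mean divides by the sum of the selected weights, not by the number of rows.
per_row = F.cross_entropy(
    logits, target, weight=weight, ignore_index=ignore_index, reduction="none"
)
target_mask = target != ignore_index
sum_non_ignore_weight = weight[target[target_mask]].sum()
manual = per_row.sum() / sum_non_ignore_weight

assert torch.allclose(ref, manual, atol=1e-6)

Under the same assumptions, LigerCrossEntropyLoss(weight=weight) should line up with this reference, which is what the new test_correctness_with_weight_once parametrizations assert.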