From 2d66515c0dd760b1e1dddda23e7029e67e568aaa Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Mon, 2 Dec 2024 16:46:31 +0800 Subject: [PATCH 1/7] Add weight support for LigerCrossEntropy --- src/liger_kernel/ops/cross_entropy.py | 49 +++++++++++++- .../transformers/cross_entropy.py | 3 + src/liger_kernel/transformers/functional.py | 1 + test/transformers/test_cross_entropy.py | 67 ++++++++++++++++++- 4 files changed, 116 insertions(+), 4 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 2a980c69e..41b5757d6 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -27,11 +27,14 @@ def liger_cross_entropy_kernel( X_stride, Y_ptr, Y_stride, + weight_ptr, + weight_stride, loss_ptr, z_loss_ptr, loss_stride, n_cols, n_non_ignore, + sum_of_non_ignore_weight, ignore_index, lse_square_scale: tl.constexpr, label_smoothing: tl.constexpr, @@ -39,6 +42,7 @@ def liger_cross_entropy_kernel( softcap, RETURN_Z_LOSS: tl.constexpr, BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, HAS_SOFTCAPPING: tl.constexpr, ): """ @@ -86,6 +90,9 @@ def liger_cross_entropy_kernel( loss_ptr += program_id * loss_stride z_loss_ptr += program_id * loss_stride + if HAS_WEIGHT: + weight = tl.load(weight_ptr + y).cast(tl.float32) + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 @@ -162,7 +169,12 @@ def liger_cross_entropy_kernel( X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) # reduction scale if reduction == "mean": - X_block = X_block / (n_non_ignore) + if HAS_WEIGHT: + X_block = X_block / (sum_of_non_ignore_weight) + else: + X_block = X_block / (n_non_ignore) + if HAS_WEIGHT: + X_block = X_block * weight # chain rule # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: @@ -201,8 +213,16 @@ def liger_cross_entropy_kernel( loss += z_loss # Normalize the loss by the number of non-ignored elements if reduction is "mean" if reduction == "mean": - z_loss = z_loss / n_non_ignore - loss = loss / n_non_ignore + if HAS_WEIGHT: + z_loss = z_loss / sum_of_non_ignore_weight + loss = loss / sum_of_non_ignore_weight + else: + z_loss = z_loss / n_non_ignore + loss = loss / n_non_ignore + + if HAS_WEIGHT: + z_loss = z_loss * weight + loss = loss * weight tl.store(loss_ptr, loss) if RETURN_Z_LOSS == _TRUE: @@ -224,6 +244,7 @@ def liger_cross_entropy_kernel( def cross_entropy_forward( _input, target, + weight, ignore_index, lse_square_scale, label_smoothing, @@ -254,6 +275,21 @@ def cross_entropy_forward( z_loss_1d = loss_1d # dummy ptr when return_z_loss == False n_non_ignore = (target != ignore_index).sum().item() + sum_of_non_ignore_weight = n_non_ignore + if weight is not None: + assert ( + weight.shape[0] == V + ), f"If given, weight has to be a Tensor of size V. Got: {weight.shape}" + assert torch.is_floating_point( + weight + ), f"If given, weight has to be a Tensor of floating point dtype. 
Got: {weight.dtype}" + selected_weight = torch.gather(weight, dim=-1, index=target) + if ignore_index >= 0 and ignore_index < V: + sum_of_non_ignore_weight = selected_weight.sum().item() + else: + sum_of_non_ignore_weight = selected_weight.sum().item() + if weight.stride(-1) != 1: + weight = weight.contiguous() # ensure _input and target are contiguous in the last dimension if _input.stride(-1) != 1: @@ -267,18 +303,22 @@ def cross_entropy_forward( X_stride=_input.stride(-2), Y_ptr=target, Y_stride=target.stride(-1), # always 1 + weight_ptr=weight if weight is not None else _input, # dummy if None + weight_stride=weight.stride(-1) if weight is not None else 0, loss_ptr=loss_1d, z_loss_ptr=z_loss_1d, loss_stride=loss_1d.stride(-1), # always 1 n_cols=V, n_non_ignore=n_non_ignore, ignore_index=ignore_index, + sum_of_non_ignore_weight=sum_of_non_ignore_weight, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap if softcap is not None else 0.0, RETURN_Z_LOSS=return_z_loss, BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=True if weight is not None else False, HAS_SOFTCAPPING=True if softcap is not None else False, # TODO: 32 seems to give the best performance # Performance is quite sensitive to num_warps @@ -329,6 +369,7 @@ def forward( ctx, _input: torch.Tensor, target: torch.Tensor, + weight: Optional[torch.FloatTensor], ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -356,6 +397,7 @@ def forward( loss, z_loss, _input = cross_entropy_forward( _input, target, + weight, ignore_index, lse_square_scale, label_smoothing, @@ -397,4 +439,5 @@ def backward(ctx, grad_output, grad_ouput2): None, None, None, + None, ) diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index 7bd27edd6..f3e51808c 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -8,6 +8,7 @@ class LigerCrossEntropyLoss(torch.nn.Module): def __init__( self, + weight: Optional[torch.FloatTensor] = None, ignore_index: int = -100, lse_square_scale: float = 0.0, label_smoothing: float = 0.0, @@ -30,6 +31,7 @@ def __init__( assert ( softcap is None or softcap > 0 ), f"softcap must greater than 0.0 or None. 
Got: {softcap}" + self.weight = weight self.ignore_index = ignore_index self.lse_square_scale = lse_square_scale self.label_smoothing = label_smoothing @@ -41,6 +43,7 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor): loss, z_loss = LigerCrossEntropyFunction.apply( _input, target, + self.weight, self.ignore_index, self.lse_square_scale, self.label_smoothing, diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 45ad6159a..5d6086caa 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -34,6 +34,7 @@ def liger_cross_entropy( loss, z_loss = LigerCrossEntropyFunction.apply( input, target, + weight, ignore_index, lse_square_scale, label_smoothing, diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 28e3ec5dc..1df09a321 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -301,6 +301,27 @@ def _test_correctness_with_z_loss_with_other_params_once( assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_weight_once( + target_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol +): + torch.manual_seed(0) + torch_ce = CrossEntropyLoss(weight=weight, reduction=reduction) + + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) + + output = torch_ce(_input, target) + output2 = target_ce(_input2, target) + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + def _test_correctness_not_last_layer_once( target_ce, B, T, V, reduction, scalar, dtype, atol, rtol ): @@ -345,6 +366,7 @@ def _test_correctness_functional( y1, y1_z = liger_cross_entropy( x1, target, + None, ignore_index=0, lse_square_scale=1e-4, label_smoothing=0.1, @@ -353,7 +375,7 @@ def _test_correctness_functional( return_z_loss=True, ) y2, y2_z = LigerCrossEntropyFunction.apply( - x2, target, 0, 1e-4, 0.1, "mean", 30.0, True + x2, target, None, 0, 1e-4, 0.1, "mean", 30.0, True ) assert torch.allclose(y1, y2, atol=atol, rtol=rtol) @@ -687,6 +709,41 @@ def test_correctness_with_z_loss_with_other_params_once( ) +@pytest.mark.parametrize( + "B, T, V", + [ + (2, 4096, 32000), # llama2, mistral + # # weird shapes + (3, 423, 32000), + ], +) +@pytest.mark.parametrize("weight", [0.5, 0.1]) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (1.0, torch.float32, 1e-8, 1e-6), + ], +) +def test_correctness_with_weight_once( + B, T, V, weight, reduction, scalar, dtype, atol, rtol +): + weight = torch.rand(V, device=device, dtype=dtype) + test_ce = LigerCrossEntropyLoss(weight=weight, reduction=reduction) + _test_correctness_with_weight_once( + test_ce, B, T, V, reduction, weight, scalar, dtype, atol, rtol + ) + + @pytest.mark.parametrize( "B, T, V", [ @@ -746,17 +803,21 @@ def test_float32_internal(): X_stride=X_bf16.stride(-2), Y_ptr=Y, Y_stride=Y.stride(-1), + weight_ptr=X_bf16, # dummy ptr, not used + 
weight_stride=X_bf16.stride(-2), z_loss_ptr=loss_bf16, # dummy ptr, not used loss_ptr=loss_bf16, loss_stride=loss_bf16.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, + sum_of_non_ignore_weight=n_non_ignore, # not used ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap, RETURN_Z_LOSS=0, # False + HAS_WEIGHT=False, HAS_SOFTCAPPING=False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, @@ -770,17 +831,21 @@ def test_float32_internal(): X_stride=X_fp32.stride(-2), Y_ptr=Y, Y_stride=Y.stride(-1), + weight_ptr=X_fp32, # dummy ptr, not used + weight_stride=X_fp32.stride(-2), loss_ptr=loss_fp32, z_loss_ptr=loss_fp32, # dummy ptr, not used loss_stride=loss_fp32.stride(-1), n_cols=n_cols, n_non_ignore=n_non_ignore, + sum_of_non_ignore_weight=n_non_ignore, # not used ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap, RETURN_Z_LOSS=0, # False + HAS_WEIGHT=False, HAS_SOFTCAPPING=False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, From dbe42378356ec123aba55a608a30e1502089813c Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:12:48 +0800 Subject: [PATCH 2/7] Update cross_entropy_kernel args in flce --- src/liger_kernel/ops/cross_entropy.py | 8 +++++--- src/liger_kernel/ops/fused_linear_cross_entropy.py | 4 ++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 41b5757d6..3e1782c1f 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -283,11 +283,13 @@ def cross_entropy_forward( assert torch.is_floating_point( weight ), f"If given, weight has to be a Tensor of floating point dtype. 
Got: {weight.dtype}" - selected_weight = torch.gather(weight, dim=-1, index=target) if ignore_index >= 0 and ignore_index < V: - sum_of_non_ignore_weight = selected_weight.sum().item() + weight_mask = torch.ones_like(weight) + weight_mask[ignore_index] = 0 + selected_weight = torch.gather(weight * weight_mask, dim=-1, index=target) else: - sum_of_non_ignore_weight = selected_weight.sum().item() + selected_weight = torch.gather(weight, dim=-1, index=target) + sum_of_non_ignore_weight = selected_weight.sum().item() if weight.stride(-1) != 1: weight = weight.contiguous() diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 191a2b3d2..a3d0406f1 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -83,17 +83,21 @@ def fused_linear_cross_entropy_forward( X_stride=logits_chunk.stride(-2), Y_ptr=target_chunk, Y_stride=target_chunk.stride(-1), # always 1 + weight_ptr=_input, # dummy ptr, not used + weight_stride=0, loss_ptr=loss_1d_slice, z_loss_ptr=loss_1d_slice, # dummy ptr, not used loss_stride=loss_1d_slice.stride(-1), # always 1 n_cols=V, n_non_ignore=n_non_ignore, + sum_of_non_ignore_weight=n_non_ignore, ignore_index=ignore_index, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, softcap=softcap if softcap is not None else 0.0, RETURN_Z_LOSS=0, # False + HAS_WEIGHT=False, HAS_SOFTCAPPING=True if softcap is not None else False, BLOCK_SIZE=BLOCK_SIZE, num_warps=32 if not is_hip() else 16, From e77018258c89cbb13e1fe2006da77cfb4f64086f Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:26:34 +0800 Subject: [PATCH 3/7] Add comments --- src/liger_kernel/ops/cross_entropy.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 3e1782c1f..35644bf68 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -54,11 +54,14 @@ def liger_cross_entropy_kernel( X_stride (int): The stride of the input tensor. Y_ptr: Pointer to target tensor. Y_stride (int): The stride of the target tensor. + weight_ptr: Pointer to weight tensor. + weight_stride (int): The stride of the weight tesnor. loss_ptr: Pointer to tensor to store the loss. z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. loss_stride (int): The stride of the loss tensor. n_cols (int): The number of columns in the input tensor. n_non_ignore (int): The number of non-ignored elements in the batch. + sum_of_non_ignore_weight (float): The denominator when `reduction="mean"` if `weight` is given. ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. @@ -66,6 +69,7 @@ def liger_cross_entropy_kernel( reduction (str): The string for the reduction to apply softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to dteremine whether assigning weight to each of the classes. HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. """ @@ -386,6 +390,7 @@ def forward( ctx : The context object. 
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. + weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size C and floating point dtype ignore_index (int): The index to ignore in the target. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. From f38e1e26e9bdf2d6aee772aeef2a58a2c5aa8e1c Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 3 Dec 2024 01:01:46 +0800 Subject: [PATCH 4/7] Add complete test with other params --- src/liger_kernel/ops/cross_entropy.py | 6 +- test/transformers/test_cross_entropy.py | 163 ++++++++++++++++++++++-- 2 files changed, 155 insertions(+), 14 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 35644bf68..d3e834895 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -28,7 +28,6 @@ def liger_cross_entropy_kernel( Y_ptr, Y_stride, weight_ptr, - weight_stride, loss_ptr, z_loss_ptr, loss_stride, @@ -69,7 +68,7 @@ def liger_cross_entropy_kernel( reduction (str): The string for the reduction to apply softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). BLOCK_SIZE (int): The block size for Triton operations. - HAS_WEIGHT (bool): The boolean value to dteremine whether assigning weight to each of the classes. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. """ @@ -310,7 +309,6 @@ def cross_entropy_forward( Y_ptr=target, Y_stride=target.stride(-1), # always 1 weight_ptr=weight if weight is not None else _input, # dummy if None - weight_stride=weight.stride(-1) if weight is not None else 0, loss_ptr=loss_1d, z_loss_ptr=z_loss_1d, loss_stride=loss_1d.stride(-1), # always 1 @@ -390,7 +388,7 @@ def forward( ctx : The context object. _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. - weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size C and floating point dtype + weight(Tensor, optional): a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype ignore_index (int): The index to ignore in the target. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. 
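For context, the weighted reduction="mean" semantics this series targets follow torch.nn.functional.cross_entropy: each token's loss is scaled by weight[target], and the mean divides by the sum of the selected weights over non-ignored targets rather than by the plain token count (the sum_of_non_ignore_weight denominator computed in cross_entropy_forward). A minimal reference sketch of that behavior, not part of the patch and with an illustrative function name:

import torch
import torch.nn.functional as F

def weighted_ce_reference(logits, target, weight, ignore_index=-100, reduction="mean"):
    # Per-token loss, already scaled by weight[target]; ignored targets contribute 0.
    per_token = F.cross_entropy(
        logits.float(),
        target,
        weight=weight.float(),
        ignore_index=ignore_index,
        reduction="none",
    )
    if reduction == "sum":
        return per_token.sum()
    if reduction == "mean":
        # Denominator is the sum of the selected weights over non-ignored targets,
        # matching the weighted denominator used by the kernel.
        mask = target != ignore_index
        return per_token.sum() / weight.float()[target[mask]].sum()
    return per_token
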
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 1df09a321..fd8df6ea9 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -20,6 +20,7 @@ class CrossEntropyWithZLoss(torch.nn.Module): def __init__( self, + weight=None, lse_square_scale=0.0, reduction="mean", ignore_index=-100, @@ -28,6 +29,7 @@ def __init__( dtype=torch.float32, ): super().__init__() + self.weight = weight self.lse_square_scale = lse_square_scale self.reduction = reduction self.ignore_index = ignore_index @@ -38,10 +40,24 @@ def __init__( def forward(self, logits, targets): # Loss calculations are all in float32 logits = logits.to(torch.float32) + HAS_WEIGHT = True if self.weight is not None else False + if HAS_WEIGHT: + self.weight = self.weight.to(torch.float32) + if self.ignore_index >= 0 and self.ignore_index < logits.shape[-1]: + weight_mask = torch.ones_like(self.weight) + weight_mask[self.ignore_index] = 0 + selected_weight = torch.gather( + self.weight * weight_mask, dim=-1, index=targets + ) + del weight_mask + else: + selected_weight = torch.gather(self.weight, dim=-1, index=targets) + sum_of_non_ignore_weight = selected_weight.sum().item() # Standard cross entropy loss ce_loss = F.cross_entropy( logits, targets, + weight=self.weight, reduction=self.reduction, label_smoothing=self.label_smoothing, ignore_index=self.ignore_index, @@ -54,9 +70,14 @@ def forward(self, logits, targets): z_loss = torch.where( targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0 ) - z_loss = z_loss.to(logits.dtype) + if HAS_WEIGHT: + z_loss = z_loss * selected_weight + if self.reduction == "mean": - z_loss = z_loss.sum() / (targets != self.ignore_index).sum() + if HAS_WEIGHT: + z_loss = z_loss.sum() / sum_of_non_ignore_weight + else: + z_loss = z_loss.sum() / (targets != self.ignore_index).sum() elif self.reduction == "sum": z_loss = z_loss.sum() else: @@ -185,13 +206,15 @@ def _test_correctness_with_softcap_once( _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar # upcasting to match liger's casting strategy - _input = _tensor.to(torch.float32).detach().clone().requires_grad_(True) + _input = _tensor.detach().clone().requires_grad_(True) _input2 = _tensor.detach().clone().requires_grad_(True) target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long) # downcasting to original dtype - output = torch_ce(softcap * torch.tanh(_input / softcap), target).to(dtype) + output = torch_ce( + softcap * torch.tanh(_input.to(torch.float32) / softcap), target + ).to(dtype) output2 = target_ce(_input2, target) assert torch.allclose(output, output2, atol=atol, rtol=rtol) @@ -322,6 +345,59 @@ def _test_correctness_with_weight_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_weight_with_other_params_once( + target_ce, + B, + T, + V, + reduction, + weight, + lse_square_scale, + ignore_index, + label_smoothing, + softcap, + scalar, + dtype, + atol, + rtol, +): + torch.manual_seed(0) + torch_ce = CrossEntropyWithZLoss( + weight=weight, + lse_square_scale=lse_square_scale, + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, + dtype=dtype, + ) + + _tensor = torch.randn(B * T, V, device=device, dtype=dtype) * scalar + # upcasting to match liger's casting strategy + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), 
device=device, dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + target[indices_to_assign] = ignore_index + + output = torch_ce( + softcap * torch.tanh(_input.to(torch.float32) / softcap), target + ).to(dtype) + output2 = target_ce(_input2, target) + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + def _test_correctness_not_last_layer_once( target_ce, B, T, V, reduction, scalar, dtype, atol, rtol ): @@ -717,7 +793,6 @@ def test_correctness_with_z_loss_with_other_params_once( (3, 423, 32000), ], ) -@pytest.mark.parametrize("weight", [0.5, 0.1]) @pytest.mark.parametrize("reduction", ["sum", "mean"]) @pytest.mark.parametrize( "scalar, dtype, atol, rtol", @@ -734,9 +809,7 @@ def test_correctness_with_z_loss_with_other_params_once( (1.0, torch.float32, 1e-8, 1e-6), ], ) -def test_correctness_with_weight_once( - B, T, V, weight, reduction, scalar, dtype, atol, rtol -): +def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, rtol): weight = torch.rand(V, device=device, dtype=dtype) test_ce = LigerCrossEntropyLoss(weight=weight, reduction=reduction) _test_correctness_with_weight_once( @@ -744,6 +817,78 @@ def test_correctness_with_weight_once( ) +@pytest.mark.parametrize( + "B, T, V", + [ + (2, 4096, 3200), # llama2, mistral + # # weird shapes + (3, 423, 3200), + ], +) +@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize( + "ignore_index, lse_square_scale, label_smoothing, softcap", + [ + (-100, 1e-4, 0.1, 30.0), + (42, 1e-5, 0.2, 40.0), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (1.0, torch.float32, 1e-8, 1e-6), + ], +) +def test_correctness_with_weight_with_other_params_once( + B, + T, + V, + reduction, + lse_square_scale, + ignore_index, + label_smoothing, + softcap, + scalar, + dtype, + atol, + rtol, +): + weight = torch.rand(V, device=device, dtype=torch.float32) # match softcap casting + test_ce = LigerCrossEntropyLoss( + weight=weight, + lse_square_scale=lse_square_scale, + reduction=reduction, + ignore_index=ignore_index, + label_smoothing=label_smoothing, + softcap=softcap, + ) + _test_correctness_with_weight_with_other_params_once( + test_ce, + B, + T, + V, + reduction, + weight, + lse_square_scale, + ignore_index, + label_smoothing, + softcap, + scalar, + dtype, + atol, + rtol, + ) + + @pytest.mark.parametrize( "B, T, V", [ @@ -804,7 +949,6 @@ def test_float32_internal(): Y_ptr=Y, Y_stride=Y.stride(-1), weight_ptr=X_bf16, # dummy ptr, not used - weight_stride=X_bf16.stride(-2), z_loss_ptr=loss_bf16, # dummy ptr, not used loss_ptr=loss_bf16, loss_stride=loss_bf16.stride(-1), @@ -832,7 +976,6 @@ def test_float32_internal(): Y_ptr=Y, Y_stride=Y.stride(-1), weight_ptr=X_fp32, # dummy ptr, not used - weight_stride=X_fp32.stride(-2), loss_ptr=loss_fp32, z_loss_ptr=loss_fp32, # dummy ptr, not used loss_stride=loss_fp32.stride(-1), From 45f6c1f9db0938a3d380a86452e3d237056ccbda Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Tue, 
3 Dec 2024 02:07:08 +0800 Subject: [PATCH 5/7] Fix invalid range access bug --- src/liger_kernel/ops/cross_entropy.py | 12 +++++------- test/transformers/test_cross_entropy.py | 21 ++++++++++++--------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index d3e834895..ea67ba6cd 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -277,7 +277,8 @@ def cross_entropy_forward( else: z_loss_1d = loss_1d # dummy ptr when return_z_loss == False - n_non_ignore = (target != ignore_index).sum().item() + target_mask = target != ignore_index + n_non_ignore = target_mask.sum().item() sum_of_non_ignore_weight = n_non_ignore if weight is not None: assert ( @@ -286,12 +287,9 @@ def cross_entropy_forward( assert torch.is_floating_point( weight ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" - if ignore_index >= 0 and ignore_index < V: - weight_mask = torch.ones_like(weight) - weight_mask[ignore_index] = 0 - selected_weight = torch.gather(weight * weight_mask, dim=-1, index=target) - else: - selected_weight = torch.gather(weight, dim=-1, index=target) + selected_weight = torch.where( + target_mask, torch.gather(weight, dim=0, index=target * target_mask), 0.0 + ) sum_of_non_ignore_weight = selected_weight.sum().item() if weight.stride(-1) != 1: weight = weight.contiguous() diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index fd8df6ea9..ff8302569 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -41,18 +41,17 @@ def forward(self, logits, targets): # Loss calculations are all in float32 logits = logits.to(torch.float32) HAS_WEIGHT = True if self.weight is not None else False + + target_mask = targets != self.ignore_index if HAS_WEIGHT: self.weight = self.weight.to(torch.float32) - if self.ignore_index >= 0 and self.ignore_index < logits.shape[-1]: - weight_mask = torch.ones_like(self.weight) - weight_mask[self.ignore_index] = 0 - selected_weight = torch.gather( - self.weight * weight_mask, dim=-1, index=targets - ) - del weight_mask - else: - selected_weight = torch.gather(self.weight, dim=-1, index=targets) + selected_weight = torch.where( + target_mask, + torch.gather(self.weight, dim=-1, index=targets * target_mask), + 0.0, + ) sum_of_non_ignore_weight = selected_weight.sum().item() + # Standard cross entropy loss ce_loss = F.cross_entropy( logits, @@ -71,7 +70,11 @@ def forward(self, logits, targets): targets != self.ignore_index, self.lse_square_scale * (lse**2), 0.0 ) if HAS_WEIGHT: + # print(f"{z_loss.shape=}") z_loss = z_loss * selected_weight + # print(f"{selected_weight.shape=}") + # print(f"{selected_weight[targets == self.ignore_index]=}") + # print(f"{selected_weight[targets != self.ignore_index]=}") if self.reduction == "mean": if HAS_WEIGHT: From a1a4f0ac85a831a8b00b6d3ff984b962a15c92b9 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sun, 15 Dec 2024 13:39:34 +0800 Subject: [PATCH 6/7] Refactor variable names and computation of target's weights --- src/liger_kernel/ops/cross_entropy.py | 58 ++++++++++++++------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index ea67ba6cd..1998e7831 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ 
-33,7 +33,7 @@ def liger_cross_entropy_kernel( loss_stride, n_cols, n_non_ignore, - sum_of_non_ignore_weight, + weight_sum, ignore_index, lse_square_scale: tl.constexpr, label_smoothing: tl.constexpr, @@ -54,13 +54,13 @@ def liger_cross_entropy_kernel( Y_ptr: Pointer to target tensor. Y_stride (int): The stride of the target tensor. weight_ptr: Pointer to weight tensor. - weight_stride (int): The stride of the weight tesnor. + weight_stride (int): The stride of the weight tensor. loss_ptr: Pointer to tensor to store the loss. z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. loss_stride (int): The stride of the loss tensor. n_cols (int): The number of columns in the input tensor. - n_non_ignore (int): The number of non-ignored elements in the batch. - sum_of_non_ignore_weight (float): The denominator when `reduction="mean"` if `weight` is given. + n_non_ignore (flaot): The number of non-ignored elements or the sum of non-ignored target's weights in the batch + weight_sum (float): The sum of weigh tensor ignore_index (int): The index to ignore in the target. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. @@ -94,7 +94,7 @@ def liger_cross_entropy_kernel( z_loss_ptr += program_id * loss_stride if HAS_WEIGHT: - weight = tl.load(weight_ptr + y).cast(tl.float32) + weight_y = tl.load(weight_ptr + y).cast(tl.float32) # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 @@ -126,7 +126,15 @@ def liger_cross_entropy_kernel( block_max = tl.max(X_block) if label_smoothing > 0: # scale X beforehand to avoid overflow - scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum( + tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0) + ) + else: + scaled_x_sum += tl.sum( + tl.where(X_offsets < n_cols, -eps * X_block, 0.0) + ) m_new = tl.maximum(m, block_max) d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new @@ -172,12 +180,9 @@ def liger_cross_entropy_kernel( X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) # reduction scale if reduction == "mean": - if HAS_WEIGHT: - X_block = X_block / (sum_of_non_ignore_weight) - else: - X_block = X_block / (n_non_ignore) + X_block = X_block / (n_non_ignore) if HAS_WEIGHT: - X_block = X_block * weight + X_block = X_block * weight_y # chain rule # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) if HAS_SOFTCAPPING: @@ -197,6 +202,8 @@ def liger_cross_entropy_kernel( # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 # So we can safely calculate log (softmax(X_y)) without overflow loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) @@ -207,7 +214,10 @@ def liger_cross_entropy_kernel( # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 if label_smoothing > 0: - smooth_loss = scaled_x_sum + 
label_smoothing * lse + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse loss = loss * (1 - label_smoothing) + smooth_loss # An auxiliary loss, z_loss @@ -216,17 +226,8 @@ def liger_cross_entropy_kernel( loss += z_loss # Normalize the loss by the number of non-ignored elements if reduction is "mean" if reduction == "mean": - if HAS_WEIGHT: - z_loss = z_loss / sum_of_non_ignore_weight - loss = loss / sum_of_non_ignore_weight - else: - z_loss = z_loss / n_non_ignore - loss = loss / n_non_ignore - - if HAS_WEIGHT: - z_loss = z_loss * weight - loss = loss * weight - + z_loss = z_loss / n_non_ignore + loss = loss / n_non_ignore tl.store(loss_ptr, loss) if RETURN_Z_LOSS == _TRUE: tl.store(z_loss_ptr, z_loss) @@ -279,7 +280,7 @@ def cross_entropy_forward( target_mask = target != ignore_index n_non_ignore = target_mask.sum().item() - sum_of_non_ignore_weight = n_non_ignore + weight_sum = weight.sum().item() if weight is not None: assert ( weight.shape[0] == V @@ -287,10 +288,11 @@ def cross_entropy_forward( assert torch.is_floating_point( weight ), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}" - selected_weight = torch.where( - target_mask, torch.gather(weight, dim=0, index=target * target_mask), 0.0 + n_non_ignore = ( + torch.gather(weight, dim=0, index=target.masked_select(target_mask)) + .sum() + .item() ) - sum_of_non_ignore_weight = selected_weight.sum().item() if weight.stride(-1) != 1: weight = weight.contiguous() @@ -313,7 +315,7 @@ def cross_entropy_forward( n_cols=V, n_non_ignore=n_non_ignore, ignore_index=ignore_index, - sum_of_non_ignore_weight=sum_of_non_ignore_weight, + weight_sum=weight_sum, lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, From 2e6ded2f8f0b01291673a39c90f1fa82c3ce0656 Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Sun, 15 Dec 2024 13:50:29 +0800 Subject: [PATCH 7/7] Fix unit test --- test/transformers/test_cross_entropy.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index f3113aaae..8c402254e 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -346,8 +346,8 @@ def _test_correctness_with_weight_once( output2 = target_ce(_input2, target) assert torch.allclose(output, output2, atol=atol, rtol=rtol) - output.backward() - output2.backward() + output.backward(gradient=torch.ones_like(output)) + output2.backward(gradient=torch.ones_like(output)) assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) @@ -397,11 +397,11 @@ def _test_correctness_with_weight_with_other_params_once( softcap * torch.tanh(_input.to(torch.float32) / softcap), target ).to(dtype) output2 = target_ce(_input2, target) - assert torch.allclose(output, output2, atol=atol, rtol=rtol) + assert_verbose_allclose(output, output2, atol=atol, rtol=rtol) - output.backward() - output2.backward() - assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + output.backward(gradient=torch.ones_like(output)) + output2.backward(gradient=torch.ones_like(output)) + assert_verbose_allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) def _test_correctness_not_last_layer_once( @@ -831,7 +831,7 @@ def test_correctness_with_weight_once(B, T, V, reduction, scalar, dtype, atol, r (3, 423, 3200), ], ) 
-@pytest.mark.parametrize("reduction", ["sum", "mean"]) +@pytest.mark.parametrize("reduction", ["sum", "mean", "none"]) @pytest.mark.parametrize( "ignore_index, lse_square_scale, label_smoothing, softcap", [
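Taken together, the series lets callers pass per-class weights through both the functional interface and LigerCrossEntropyLoss. A minimal usage sketch mirroring the added tests — the shapes, random data, and CUDA device string are illustrative, and return_z_loss is left at its default so only the loss is used:

import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

B, T, V = 2, 1024, 32000
device = "cuda"

weight = torch.rand(V, device=device, dtype=torch.float32)  # per-class rescaling weights
loss_fn = LigerCrossEntropyLoss(weight=weight, reduction="mean", ignore_index=-100)

logits = torch.randn(B * T, V, device=device, requires_grad=True)
target = torch.randint(0, V, (B * T,), device=device, dtype=torch.long)

loss = loss_fn(logits, target)  # weighted mean over non-ignored tokens
loss.backward()                 # each row's gradient is scaled by weight[target] in the kernel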