diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index 32703788c..901809d4d 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -14,6 +14,7 @@ def liger_cross_entropy_kernel(
     n_cols,
     n_non_ignore,
     ignore_index,
+    label_smoothing: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
     """
@@ -30,6 +31,7 @@ def liger_cross_entropy_kernel(
         n_cols (int): The number of columns in the input tensor.
         n_non_ignore (int): The number of non-ignored elements in the batch.
         ignore_index (int): The index to ignore in the target.
+        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
         BLOCK_SIZE (int): The block size for Triton operations.
     """
@@ -63,12 +65,20 @@ def liger_cross_entropy_kernel(
         X_ptr + y
     )  # we need to store the original value of X_y for the loss calculation

+    # Label smoothing is a general case of normal cross entropy
+    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
+    scaled_x_sum = 0.0
+    eps = label_smoothing / n_cols
+
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
         block_max = tl.max(X_block)
+        if label_smoothing > 0:
+            # scale X beforehand to avoid overflow
+            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new
@@ -77,12 +87,16 @@
     # dx_y = (softmax(x_y) - 1) / N
     # dx_i = softmax(x_i) / N, i != y
     # N is the number of non ignored elements in the batch
+    # For label smoothing:
+    # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
+    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
+    #      = dx_i - (1 - label_smoothing) / N
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
-        X_block = (tl.exp(X_block - m) / d) / (n_non_ignore)
+        X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
         tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
@@ -97,9 +111,21 @@
     # So we can safely calculate log (softmax(X_y)) without overflow
     loss = -(ori_X_y - m - tl.log(d))

-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - 1) / N`
+    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
+    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
+    #          = (1 - label_smoothing) * H(q, p) - eps * sum(logsoftmax(x_i))
+    # By using m (global max of x_i) and d (sum of e^(x_i - m)), we can simplify as:
+    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + log d))
+    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
+    # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
+    if label_smoothing > 0:
+        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
+        loss = loss * (1 - label_smoothing) + smooth_loss
+
+    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N`
     X_y = tl.load(X_ptr + y)
-    X_y += -1 / (n_non_ignore)
+    X_y += -(1 - label_smoothing) / (n_non_ignore)

     tl.store(loss_ptr, loss)
     tl.store(X_ptr + y, X_y)
@@ -147,7 +173,7 @@ def element_mul_kernel(
     tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)


-def cross_entropy_forward(_input, target, ignore_index):
+def cross_entropy_forward(_input, target, ignore_index, label_smoothing):
     BT, V = _input.shape
     n_rows = BT

@@ -175,6 +201,7 @@ def cross_entropy_forward(_input, target, ignore_index):
         n_cols=V,
         n_non_ignore=n_non_ignore,
         ignore_index=ignore_index,
+        label_smoothing=label_smoothing,
         BLOCK_SIZE=BLOCK_SIZE,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
@@ -216,7 +243,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, _input, target, ignore_index):
+    def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0):
         """
         The forward pass of the Liger Cross Entropy loss.

@@ -225,11 +252,14 @@ def forward(ctx, _input, target, ignore_index):
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
         ignore_index (int): The index to ignore in the target.
+        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.

         Returns:
         tensor: The computed loss.
         """
-        loss, _input = cross_entropy_forward(_input, target, ignore_index)
+        loss, _input = cross_entropy_forward(
+            _input, target, ignore_index, label_smoothing
+        )
         # TODO: investigation
         # If we don't detach the _input tensor, the memory will double
         # Not sure why but seems that there will be a time both grad and value exist but in different location
@@ -254,4 +284,5 @@ def backward(ctx, grad_output):
             _input,
             None,
             None,
+            None,
         )
diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index 7b62dbbb1..bf0e3da48 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -13,7 +13,7 @@ def fused_linear_cross_entropy_forward(
-    _input, weight, target, bias=None, ignore_index=-100
+    _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
 ):
     dtype = (
         torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
     )
@@ -80,6 +80,7 @@ def fused_linear_cross_entropy_forward(
             n_cols=V,
             n_non_ignore=n_non_ignore,
             ignore_index=ignore_index,
+            label_smoothing=label_smoothing,
             BLOCK_SIZE=BLOCK_SIZE,
             num_warps=32,
         )
@@ -171,7 +172,9 @@ def fused_linear_cross_entropy_backward(
 class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, _input, weight, target, bias=None, ignore_index=-100):
+    def forward(
+        ctx, _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
+    ):
         """
         Fusing the last linear layer with cross-entropy loss
         Reference: https://github.com/mgmalek/efficient_cross_entropy
@@ -188,7 +191,7 @@ def forward(ctx, _input, weight, target, bias=None, ignore_index=-100):
             ignore_index: the index to ignore in the target
         """
         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input, weight, target, bias, ignore_index
+            _input, weight, target, bias, ignore_index, label_smoothing
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py
index 44255eb85..0adb1cc87 100644
--- a/src/liger_kernel/transformers/cross_entropy.py
+++ b/src/liger_kernel/transformers/cross_entropy.py
@@ -6,6 +6,11 @@ class LigerCrossEntropyLoss(CrossEntropyLoss):
     def __init__(self, *args, **kwargs):
         super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
+        assert (self.label_smoothing >= 0) and (
+            self.label_smoothing <= 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"

     def forward(self, _input, target):
-        return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index)
+        return LigerCrossEntropyFunction.apply(
+            _input, target, self.ignore_index, self.label_smoothing
+        )
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py
index c1fcccb34..cdc7000e5 100644
--- a/test/transformers/test_cross_entropy.py
+++ b/test/transformers/test_cross_entropy.py
@@ -61,6 +61,61 @@ def _test_correctness_with_ignore_index_once(
     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


+def _test_correctness_with_label_smoothing_once(
+    target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
+):
+    torch.manual_seed(0)
+    torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing)
+
+    _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
+    _input = _tensor.detach().clone().requires_grad_(True)
+    _input2 = _tensor.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)
+
+    output = torch_ce(_input, target)
+    output2 = target_ce(_input2, target)
+
+    assert torch.allclose(output, output2, atol=atol, rtol=rtol)
+
+    output.backward()
+    output2.backward()
+    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
+
+
+def _test_correctness_with_label_smoothing_with_ignore_index_once(
+    target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
+):
+    torch.manual_seed(0)
+    torch_ce = CrossEntropyLoss(
+        ignore_index=ignore_index, label_smoothing=label_smoothing
+    )
+
+    _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
+    _input = _tensor.detach().clone().requires_grad_(True)
+    _input2 = _tensor.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)
+
+    # Assign some random number of elements as ignore_index
+    num_elements_to_assign = torch.randint(
+        1, B * T // 2, (1,)
+    ).item()  # Random number of elements to set to ignore_index
+    indices_to_assign = torch.randperm(B * T)[
+        :num_elements_to_assign
+    ]  # Randomly select indices
+    target[indices_to_assign] = ignore_index
+
+    output = torch_ce(_input, target)
+    output2 = target_ce(_input2, target)
+
+    assert torch.allclose(output, output2, atol=atol, rtol=rtol)
+
+    output.backward()
+    output2.backward()
+    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
+
+
 def _test_correctness_not_last_layer_once(
     target_ce, B, T, V, scalar, dtype, atol, rtol
 ):
@@ -248,6 +303,125 @@ def test_correctness_with_ignore_index(
     )


+@pytest.mark.parametrize(
+    "B, T, V, label_smoothing",
+    [
+        (2, 4096, 32000, 0.1),  # llama2, mistral
+        (2, 4096, 32000, 0.1),  # llama2, mistral
+        (1, 4096, 128256, 0.1),  # llama3
+        # weird shapes
+        (3, 423, 32000, 0.1),
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        pytest.param(
+            0.1,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        pytest.param(
+            1.0,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        pytest.param(
+            10.0,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        (0.1, torch.float32, 1e-8, 1e-6),
+        (1.0, torch.float32, 1e-8, 1e-6),
+        (10.0, torch.float32, 1e-8, 1e-6),
+    ],
+)
+@pytest.mark.skipif(
+    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
+    reason="Needs 16GB+ GPU memory.",
+)
+def test_correctness_with_label_smoothing_once(
+    B, T, V, label_smoothing, scalar, dtype, atol, rtol
+):
+    liger_ce = LigerCrossEntropyLoss(label_smoothing=label_smoothing)
+    _test_correctness_with_label_smoothing_once(
+        liger_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
+    )
+
+
+@pytest.mark.parametrize(
+    "B, T, V, ignore_index, label_smoothing",
+    [
+        (2, 4096, 32000, 1, 0.1),  # llama2, mistral
+        (2, 4096, 32000, -100, 0.2),  # llama2, mistral
+        (1, 4096, 128256, 2, 0.1),  # llama3
+        # weird shapes
+        (3, 423, 32000, -300, 0.2),
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        pytest.param(
+            0.1,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        pytest.param(
+            1.0,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        pytest.param(
+            10.0,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        (0.1, torch.float32, 1e-8, 1e-6),
+        (1.0, torch.float32, 1e-8, 1e-6),
+        (10.0, torch.float32, 1e-8, 1e-6),
+    ],
+)
+@pytest.mark.skipif(
+    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
+    reason="Needs 16GB+ GPU memory.",
+)
+def test_correctness_with_label_smoothing_with_ignore_index_once(
+    B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
+):
+    liger_ce = LigerCrossEntropyLoss(
+        ignore_index=ignore_index,
+        label_smoothing=label_smoothing,
+    )
+    _test_correctness_with_label_smoothing_with_ignore_index_once(
+        liger_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
+    )
+
+
 @pytest.mark.parametrize(
     "B, T, V",
     [