From f8e1b5479ab13da4f7dd9fc290c2b6a81c0d867f Mon Sep 17 00:00:00 2001 From: Tcc0403 Date: Wed, 4 Sep 2024 03:04:12 +0800 Subject: [PATCH 01/10] Add label_smoothing parameter for cross entropy --- src/liger_kernel/ops/cross_entropy.py | 35 +++- .../transformers/cross_entropy.py | 4 +- test/transformers/test_cross_entropy.py | 177 ++++++++++++++++++ 3 files changed, 210 insertions(+), 6 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 32703788c..9cebeb908 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -14,6 +14,7 @@ def liger_cross_entropy_kernel( n_cols, n_non_ignore, ignore_index, + label_smoothing: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ @@ -30,6 +31,7 @@ def liger_cross_entropy_kernel( n_cols (int): The number of columns in the input tensor. n_non_ignore (int): The number of non-ignored elements in the batch. ignore_index (int): The index to ignore in the target. + label_smoothing (float): for label smoothing BLOCK_SIZE (int): The block size for Triton operations. """ @@ -73,16 +75,30 @@ def liger_cross_entropy_kernel( d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new + # we need to compute sum(softmax(x_i)) for smooth_loss + smooth_loss = 0.0 + if label_smoothing > 0: + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, mask=X_offsets < n_cols, other=(m + tl.log(d)) + ) + smooth_loss += -tl.sum(X_block - m - tl.log(d)) + # 4. [Online softmax] second pass: calculate the gradients # dx_y = (softmax(x_y) - 1) / N # dx_i = softmax(x_i) / N, i != y # N is the number of non ignored elements in the batch + # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + eps = label_smoothing / n_cols for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load( X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf") ) - X_block = (tl.exp(X_block - m) / d) / (n_non_ignore) + X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in @@ -96,10 +112,14 @@ def liger_cross_entropy_kernel( # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 # So we can safely calculate log (softmax(X_y)) without overflow loss = -(ori_X_y - m - tl.log(d)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 + if label_smoothing > 0: + loss = loss * (1 - label_smoothing) + smooth_loss * eps # 6. 
Specially handle the i==y case where `dx_y = (softmax(x_y) - 1) / N` X_y = tl.load(X_ptr + y) - X_y += -1 / (n_non_ignore) + X_y += -(1 - label_smoothing) / (n_non_ignore) tl.store(loss_ptr, loss) tl.store(X_ptr + y, X_y) @@ -147,7 +167,7 @@ def element_mul_kernel( tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) -def cross_entropy_forward(_input, target, ignore_index): +def cross_entropy_forward(_input, target, ignore_index, label_smoothing): BT, V = _input.shape n_rows = BT @@ -175,6 +195,7 @@ def cross_entropy_forward(_input, target, ignore_index): n_cols=V, n_non_ignore=n_non_ignore, ignore_index=ignore_index, + label_smoothing=label_smoothing, BLOCK_SIZE=BLOCK_SIZE, # TODO: 32 seems to give the best performance # Performance is quite sensitive to num_warps @@ -216,7 +237,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, _input, target, ignore_index): + def forward(ctx, _input, target, ignore_index, label_smoothing=0.0): """ The forward pass of the Liger Cross Entropy loss. @@ -225,11 +246,14 @@ def forward(ctx, _input, target, ignore_index): _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. ignore_index (int): The index to ignore in the target. + label_smoothing (float): A float in [0.0, 1.0]. Returns: tensor: The computed loss. """ - loss, _input = cross_entropy_forward(_input, target, ignore_index) + loss, _input = cross_entropy_forward( + _input, target, ignore_index, label_smoothing + ) # TODO: investigation # If we don't detach the _input tensor, the memory will double # Not sure why but seems that there will be a time both grad and value exist but in different location @@ -254,4 +278,5 @@ def backward(ctx, grad_output): _input, None, None, + None, ) diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index 44255eb85..1df0fdf7f 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -8,4 +8,6 @@ def __init__(self, *args, **kwargs): super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs) def forward(self, _input, target): - return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index) + return LigerCrossEntropyFunction.apply( + _input, target, self.ignore_index, self.label_smoothing + ) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index c1fcccb34..2fa5a9554 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -61,6 +61,61 @@ def _test_correctness_with_ignore_index_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_label_smoothing_once( + target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol +): + torch.manual_seed(0) + torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing) + + _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + + output = torch_ce(_input, target) + output2 = target_ce(_input2, target) + + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert torch.allclose(_input.grad, _input2.grad, 
atol=atol, rtol=rtol) + + +def _test_correctness_with_label_smoothing_and_ignore_index_once( + target_ce, B, T, V, label_smoothing, ignore_index, scalar, dtype, atol, rtol +): + torch.manual_seed(0) + torch_ce = CrossEntropyLoss( + ignore_index=ignore_index, label_smoothing=label_smoothing + ) + + _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + target[indices_to_assign] = ignore_index + + output = torch_ce(_input, target) + output2 = target_ce(_input2, target) + + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + def _test_correctness_not_last_layer_once( target_ce, B, T, V, scalar, dtype, atol, rtol ): @@ -248,6 +303,128 @@ def test_correctness_with_ignore_index( ) +@pytest.mark.parametrize( + "B, T, V, label_smoothing", + [ + (2, 4096, 32000, 0.1), # llama2, mistral + (2, 4096, 32000, 0.1), # llama2, mistral + (1, 4096, 128256, 0.1), # llama3 + # weird shapes + (3, 423, 32000, 0.1), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 0.1, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + pytest.param( + 10.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (0.1, torch.float32, 1e-8, 1e-6), + (1.0, torch.float32, 1e-8, 1e-6), + (10.0, torch.float32, 1e-8, 1e-6), + ], +) +@pytest.mark.skipif( + torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, + reason="Needs 16GB+ GPU memory.", +) +def test_correctness_with_label_smoothing_once( + B, T, V, label_smoothing, scalar, dtype, atol, rtol +): + liger_ce = LigerCrossEntropyLoss(label_smoothing=label_smoothing) + _test_correctness_with_label_smoothing_once( + liger_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol + ) + + +@pytest.mark.parametrize( + "B, T, V, label_smoothing, ignore_index", + [ + (2, 4096, 32000, 0.1, 2), # llama2, mistral + (2, 4096, 32000, 0.0, 2), # llama2, mistral + (2, 4096, 32000, 0.1, -100), # llama2, mistral + (2, 4096, 32000, 0.0, -100), # llama2, mistral + # (1, 4096, 128256, 0.1), # llama3 + # weird shapes + (3, 423, 32000, 0.1, 2), + (3, 423, 32000, 0.0, 2), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 0.1, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + pytest.param( + 10.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 
not supported on this GPU" + ), + ), + (0.1, torch.float32, 1e-8, 1e-6), + (1.0, torch.float32, 1e-8, 1e-6), + (10.0, torch.float32, 1e-8, 1e-6), + ], +) +@pytest.mark.skipif( + torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, + reason="Needs 16GB+ GPU memory.", +) +def test_correctness_with_label_smoothing_and_ignore_index_once( + B, T, V, label_smoothing, ignore_index, scalar, dtype, atol, rtol +): + liger_ce = LigerCrossEntropyLoss( + label_smoothing=label_smoothing, ignore_index=ignore_index + ) + + _test_correctness_with_label_smoothing_and_ignore_index_once( + liger_ce, B, T, V, label_smoothing, ignore_index, scalar, dtype, atol, rtol + ) + + @pytest.mark.parametrize( "B, T, V", [ From a92a57b6c4510d6d6c95e3f5fbda244a7958081a Mon Sep 17 00:00:00 2001 From: Tcc0403 Date: Wed, 4 Sep 2024 03:28:14 +0800 Subject: [PATCH 02/10] Add the missing argument in flce function --- .../ops/fused_linear_cross_entropy.py | 3 +- test/transformers/test_cross_entropy.py | 97 ------------------- 2 files changed, 2 insertions(+), 98 deletions(-) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 7b62dbbb1..df210c5be 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -13,7 +13,7 @@ def fused_linear_cross_entropy_forward( - _input, weight, target, bias=None, ignore_index=-100 + _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0 ): dtype = ( torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype @@ -80,6 +80,7 @@ def fused_linear_cross_entropy_forward( n_cols=V, n_non_ignore=n_non_ignore, ignore_index=ignore_index, + label_smoothing=label_smoothing, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, ) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index 2fa5a9554..b99f2b099 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -83,39 +83,6 @@ def _test_correctness_with_label_smoothing_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) -def _test_correctness_with_label_smoothing_and_ignore_index_once( - target_ce, B, T, V, label_smoothing, ignore_index, scalar, dtype, atol, rtol -): - torch.manual_seed(0) - torch_ce = CrossEntropyLoss( - ignore_index=ignore_index, label_smoothing=label_smoothing - ) - - _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar - _input = _tensor.detach().clone().requires_grad_(True) - _input2 = _tensor.detach().clone().requires_grad_(True) - - target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) - - # Assign some random number of elements as ignore_index - num_elements_to_assign = torch.randint( - 1, B * T // 2, (1,) - ).item() # Random number of elements to set to ignore_index - indices_to_assign = torch.randperm(B * T)[ - :num_elements_to_assign - ] # Randomly select indices - target[indices_to_assign] = ignore_index - - output = torch_ce(_input, target) - output2 = target_ce(_input2, target) - - assert torch.allclose(output, output2, atol=atol, rtol=rtol) - - output.backward() - output2.backward() - assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) - - def _test_correctness_not_last_layer_once( target_ce, B, T, V, scalar, dtype, atol, rtol ): @@ -361,70 +328,6 @@ def test_correctness_with_label_smoothing_once( ) -@pytest.mark.parametrize( - "B, T, V, label_smoothing, ignore_index", - [ - (2, 4096, 
32000, 0.1, 2), # llama2, mistral - (2, 4096, 32000, 0.0, 2), # llama2, mistral - (2, 4096, 32000, 0.1, -100), # llama2, mistral - (2, 4096, 32000, 0.0, -100), # llama2, mistral - # (1, 4096, 128256, 0.1), # llama3 - # weird shapes - (3, 423, 32000, 0.1, 2), - (3, 423, 32000, 0.0, 2), - ], -) -@pytest.mark.parametrize( - "scalar, dtype, atol, rtol", - [ - pytest.param( - 0.1, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - pytest.param( - 1.0, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - pytest.param( - 10.0, - torch.bfloat16, - 1e-8, - 5e-2, - marks=pytest.mark.skipif( - not supports_bfloat16(), reason="bfloat16 not supported on this GPU" - ), - ), - (0.1, torch.float32, 1e-8, 1e-6), - (1.0, torch.float32, 1e-8, 1e-6), - (10.0, torch.float32, 1e-8, 1e-6), - ], -) -@pytest.mark.skipif( - torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, - reason="Needs 16GB+ GPU memory.", -) -def test_correctness_with_label_smoothing_and_ignore_index_once( - B, T, V, label_smoothing, ignore_index, scalar, dtype, atol, rtol -): - liger_ce = LigerCrossEntropyLoss( - label_smoothing=label_smoothing, ignore_index=ignore_index - ) - - _test_correctness_with_label_smoothing_and_ignore_index_once( - liger_ce, B, T, V, label_smoothing, ignore_index, scalar, dtype, atol, rtol - ) - - @pytest.mark.parametrize( "B, T, V", [ From 7388bbc7e0694bf7128c3594f5b5d6258d93bf2a Mon Sep 17 00:00:00 2001 From: Tcc0403 Date: Wed, 4 Sep 2024 04:27:29 +0800 Subject: [PATCH 03/10] Improve comment readability --- src/liger_kernel/ops/cross_entropy.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 9cebeb908..0c47cbed9 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -31,7 +31,7 @@ def liger_cross_entropy_kernel( n_cols (int): The number of columns in the input tensor. n_non_ignore (int): The number of non-ignored elements in the batch. ignore_index (int): The index to ignore in the target. - label_smoothing (float): for label smoothing + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. BLOCK_SIZE (int): The block size for Triton operations. 
""" @@ -81,7 +81,11 @@ def liger_cross_entropy_kernel( for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load( - X_ptr + X_offsets, mask=X_offsets < n_cols, other=(m + tl.log(d)) + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=( + m + tl.log(d) + ), # out-of-bounds will become 0 after calculating softmax ) smooth_loss += -tl.sum(X_block - m - tl.log(d)) @@ -89,7 +93,8 @@ def liger_cross_entropy_kernel( # dx_y = (softmax(x_y) - 1) / N # dx_i = softmax(x_i) / N, i != y # N is the number of non ignored elements in the batch - # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols + # For label smoothing: + # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N # = dx_i - (1 - label_smoothing) / N eps = label_smoothing / n_cols @@ -112,6 +117,9 @@ def liger_cross_entropy_kernel( # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 # So we can safely calculate log (softmax(X_y)) without overflow loss = -(ori_X_y - m - tl.log(d)) + + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + (label_smoothing / V) * sum(softmax(x_i)) # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 if label_smoothing > 0: From 56f1ed7d0a3d0e9a8f5ba39c9230c946bf6ee56c Mon Sep 17 00:00:00 2001 From: Tcc0403 Date: Wed, 4 Sep 2024 21:45:27 +0800 Subject: [PATCH 04/10] Fix a equation mistake --- src/liger_kernel/ops/cross_entropy.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 0c47cbed9..780435a1f 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -75,8 +75,9 @@ def liger_cross_entropy_kernel( d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new - # we need to compute sum(softmax(x_i)) for smooth_loss + # we need to compute sum(log(softmax(x_i))) for smooth_loss smooth_loss = 0.0 + log_d = tl.log(d) # avoid redundant calculations if label_smoothing > 0: for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) @@ -84,10 +85,10 @@ def liger_cross_entropy_kernel( X_ptr + X_offsets, mask=X_offsets < n_cols, other=( - m + tl.log(d) + m + log_d ), # out-of-bounds will become 0 after calculating softmax ) - smooth_loss += -tl.sum(X_block - m - tl.log(d)) + smooth_loss += -tl.sum(X_block - m - log_d) # 4. 
From 56f1ed7d0a3d0e9a8f5ba39c9230c946bf6ee56c Mon Sep 17 00:00:00 2001
From: Tcc0403
Date: Wed, 4 Sep 2024 21:45:27 +0800
Subject: [PATCH 04/10] Fix an equation mistake

---
 src/liger_kernel/ops/cross_entropy.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index 0c47cbed9..780435a1f 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -75,8 +75,9 @@ def liger_cross_entropy_kernel(
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new

-    # we need to compute sum(softmax(x_i)) for smooth_loss
+    # we need to compute sum(log(softmax(x_i))) for smooth_loss
     smooth_loss = 0.0
+    log_d = tl.log(d)  # avoid redundant calculations
     if label_smoothing > 0:
         for i in range(0, n_cols, BLOCK_SIZE):
             X_offsets = i + tl.arange(0, BLOCK_SIZE)
             X_block = tl.load(
                 X_ptr + X_offsets,
                 mask=X_offsets < n_cols,
                 other=(
-                    m + tl.log(d)
+                    m + log_d
                 ),  # out-of-bounds will become 0 after calculating softmax
             )
-            smooth_loss += -tl.sum(X_block - m - tl.log(d))
+            smooth_loss += -tl.sum(X_block - m - log_d)

     # 4. [Online softmax] second pass: calculate the gradients
     # dx_y = (softmax(x_y) - 1) / N
     # dx_i = softmax(x_i) / N, i != y
     # N is the number of non ignored elements in the batch
     # For label smoothing:
     # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
     # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
     #     = dx_i - (1 - label_smoothing) / N
     eps = label_smoothing / n_cols
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
@@ -116,8 +117,9 @@ def liger_cross_entropy_kernel(
     # = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
-    loss = -(ori_X_y - m - tl.log(d))
+    loss = -(ori_X_y - m - log_d)

+    # Original loss = H(q, p), with label smoothing regularization = H(q', p)
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
     # = (1 - label_smoothing) * H(q, p) + (label_smoothing / V) * sum(softmax(x_i))
     # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     if label_smoothing > 0:

From 43d4ae29a5f8fdef2ce666400bea9d8e6e417f3e Mon Sep 17 00:00:00 2001
From: Tcc0403
Date: Thu, 5 Sep 2024 05:00:39 +0800
Subject: [PATCH 05/10] Improve the parameter comment for the public API

---
 src/liger_kernel/ops/cross_entropy.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index 780435a1f..fed6eb13b 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -256,7 +256,7 @@ def forward(ctx, _input, target, ignore_index, label_smoothing=0.0):
         _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
         target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
         ignore_index (int): The index to ignore in the target.
-        label_smoothing (float): A float in [0.0, 1.0].
+        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.

         Returns:
         tensor: The computed loss.
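For orientation, a minimal call into the autograd function as of this point in the series could look like the sketch below. The shapes are toy values chosen for illustration, and note that ignore_index is still a required positional argument here — PATCH 07 later gives it a default of -100.

import torch
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction

BT, V = 8, 32  # (batch * sequence length, vocab size)
logits = torch.randn(BT, V, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")

# Positional args: _input, target, ignore_index, label_smoothing
loss = LigerCrossEntropyFunction.apply(logits, target, -100, 0.1)
loss.backward()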
From a9361e00e6833b41dcc3428b6d75b7fea5ea244d Mon Sep 17 00:00:00 2001
From: Tcc0403
Date: Thu, 5 Sep 2024 20:30:53 +0800
Subject: [PATCH 06/10] Add a unit test

---
 test/transformers/test_cross_entropy.py | 94 +++++++++++++++++++++++++
 1 file changed, 94 insertions(+)

diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py
index b99f2b099..6bcb6ce32 100644
--- a/test/transformers/test_cross_entropy.py
+++ b/test/transformers/test_cross_entropy.py
@@ -83,6 +83,39 @@ def _test_correctness_with_label_smoothing_once(
     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


+def _test_correctness_with_parameters_once(
+    target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
+):
+    torch.manual_seed(0)
+    torch_ce = CrossEntropyLoss(
+        ignore_index=ignore_index, label_smoothing=label_smoothing
+    )
+
+    _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
+    _input = _tensor.detach().clone().requires_grad_(True)
+    _input2 = _tensor.detach().clone().requires_grad_(True)
+
+    target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)
+
+    # Assign some random number of elements as ignore_index
+    num_elements_to_assign = torch.randint(
+        1, B * T // 2, (1,)
+    ).item()  # Random number of elements to set to ignore_index
+    indices_to_assign = torch.randperm(B * T)[
+        :num_elements_to_assign
+    ]  # Randomly select indices
+    target[indices_to_assign] = ignore_index
+
+    output = torch_ce(_input, target)
+    output2 = target_ce(_input2, target)
+
+    assert torch.allclose(output, output2, atol=atol, rtol=rtol)
+
+    output.backward()
+    output2.backward()
+    assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)
+
+
 def _test_correctness_not_last_layer_once(
     target_ce, B, T, V, scalar, dtype, atol, rtol
 ):
@@ -328,6 +361,67 @@ def test_correctness_with_label_smoothing_once(
     )


+@pytest.mark.parametrize(
+    "B, T, V, ignore_index, label_smoothing",
+    [
+        (2, 4096, 32000, 1, 0.1),  # llama2, mistral
+        (2, 4096, 32000, -100, 0.2),  # llama2, mistral
+        (1, 4096, 128256, 2, 0.1),  # llama3
+        # weird shapes
+        (3, 423, 32000, -300, 0.2),
+    ],
+)
+@pytest.mark.parametrize(
+    "scalar, dtype, atol, rtol",
+    [
+        pytest.param(
+            0.1,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        pytest.param(
+            1.0,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        pytest.param(
+            10.0,
+            torch.bfloat16,
+            1e-8,
+            5e-2,
+            marks=pytest.mark.skipif(
+                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
+            ),
+        ),
+        (0.1, torch.float32, 1e-8, 1e-6),
+        (1.0, torch.float32, 1e-8, 1e-6),
+        (10.0, torch.float32, 1e-8, 1e-6),
+    ],
+)
+@pytest.mark.skipif(
+    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
+    reason="Needs 16GB+ GPU memory.",
+)
+def test_correctness_with_parameters_once(
+    B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
+):
+    liger_ce = LigerCrossEntropyLoss(
+        ignore_index=ignore_index,
+        label_smoothing=label_smoothing,
+    )
+    _test_correctness_with_parameters_once(
+        liger_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
+    )
+
+
 @pytest.mark.parametrize(
     "B, T, V",
     [
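The new test boils down to the comparison sketched below. This is a condensed, hypothetical version with made-up toy sizes; the real test sweeps the shapes, dtypes, and tolerances parametrized above, and each loss gets its own copy of the input because the Liger kernel reuses the input buffer to store gradients.

import torch
from torch.nn import CrossEntropyLoss
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

torch.manual_seed(0)
B, T, V, ignore_index, ls = 2, 16, 64, -100, 0.1
logits = torch.randn(B * T, V, device="cuda")
target = torch.randint(0, V, (B * T,), device="cuda")
target[::5] = ignore_index  # mark a few positions as ignored

torch_ce = CrossEntropyLoss(ignore_index=ignore_index, label_smoothing=ls)
liger_ce = LigerCrossEntropyLoss(ignore_index=ignore_index, label_smoothing=ls)

# Separate copies: the Liger forward pass mutates its input in place.
assert torch.allclose(
    torch_ce(logits.clone(), target),
    liger_ce(logits.clone(), target),
    atol=1e-6,
    rtol=1e-5,
)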
From 5a0fba7d3fdaa5dc45d7195fbc1d7a7cbdaaf399 Mon Sep 17 00:00:00 2001
From: Tcc0403
Date: Fri, 6 Sep 2024 18:36:41 +0800
Subject: [PATCH 07/10] Modify label_smoothing implementation to improve time
 efficiency

---
 src/liger_kernel/ops/cross_entropy.py | 36 ++++++++++++++----------------------
 1 file changed, 14 insertions(+), 22 deletions(-)

diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index fed6eb13b..ea5d7f3c1 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -65,31 +65,22 @@ def liger_cross_entropy_kernel(
         X_ptr + y
     )  # we need to store the original value of X_y for the loss calculation

+    scaled_x_sum = 0.0
+    eps = label_smoothing / n_cols
+
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
             X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
         )
         block_max = tl.max(X_block)
+        if label_smoothing > 0:
+            # scale X beforehand to avoid overflow
+            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
         m_new = tl.maximum(m, block_max)
         d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
         m = m_new

-    # we need to compute sum(log(softmax(x_i))) for smooth_loss
-    smooth_loss = 0.0
-    log_d = tl.log(d)  # avoid redundant calculations
-    if label_smoothing > 0:
-        for i in range(0, n_cols, BLOCK_SIZE):
-            X_offsets = i + tl.arange(0, BLOCK_SIZE)
-            X_block = tl.load(
-                X_ptr + X_offsets,
-                mask=X_offsets < n_cols,
-                other=(
-                    m + log_d
-                ),  # out-of-bounds will become 0 after calculating softmax
-            )
-            smooth_loss += -tl.sum(X_block - m - log_d)
-
     # 4. [Online softmax] second pass: calculate the gradients
     # dx_y = (softmax(x_y) - 1) / N
     # dx_i = softmax(x_i) / N, i != y
     # N is the number of non ignored elements in the batch
@@ -98,7 +89,6 @@ def liger_cross_entropy_kernel(
     # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
     # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
     #     = dx_i - (1 - label_smoothing) / N
-    eps = label_smoothing / n_cols
     for i in range(0, n_cols, BLOCK_SIZE):
         X_offsets = i + tl.arange(0, BLOCK_SIZE)
         X_block = tl.load(
@@ -117,10 +107,19 @@ def liger_cross_entropy_kernel(
     # = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
     # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
     # So we can safely calculate log (softmax(X_y)) without overflow
-    loss = -(ori_X_y - m - log_d)
+    loss = -(ori_X_y - m - tl.log(d))

-    # Original loss = H(q, p), with label smoothing regularization = H(q', p)
+    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
-    # = (1 - label_smoothing) * H(q, p) + (label_smoothing / V) * sum(softmax(x_i))
+    # = (1 - label_smoothing) * H(q, p) + eps * sum(softmax(x_i))
+    # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
     # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
     if label_smoothing > 0:
-        loss = loss * (1 - label_smoothing) + smooth_loss * eps
+        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
+        loss = loss * (1 - label_smoothing) + smooth_loss

-    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - 1) / N`
+    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing)) / N`
     X_y = tl.load(X_ptr + y)
     X_y += -(1 - label_smoothing) / (n_non_ignore)

@@ -247,7 +239,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, _input, target, ignore_index, label_smoothing=0.0):
+    def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0):
         """
         The forward pass of the Liger Cross Entropy loss.
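The single-pass rewrite above relies on an algebraic identity: since log softmax(x_i) = x_i - m - log d and eps * V = label_smoothing, the smoothing term eps * sum(-log softmax(x_i)) equals sum(-eps * x_i) + label_smoothing * (m + log d), so it can be accumulated during the first pass. A quick numerical check of that identity (illustrative, not part of the patch):

import torch

torch.manual_seed(0)
V, ls = 1000, 0.1
eps = ls / V
x = torch.randn(V, dtype=torch.float64)

m = x.max()                                      # running max in the kernel
log_d = torch.log(torch.exp(x - m).sum())        # log of the softmax denominator

two_pass = eps * -(x - m - log_d).sum()          # eps * sum(-log softmax(x))
one_pass = (-eps * x).sum() + ls * (m + log_d)   # what the kernel now accumulates
assert torch.allclose(two_pass, one_pass)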
From 167083e1c2bfcb872be6120e3572a6e9eaef82cf Mon Sep 17 00:00:00 2001
From: Tcc0403
Date: Sat, 7 Sep 2024 03:26:26 +0800
Subject: [PATCH 08/10] Update suggested changes

---
 src/liger_kernel/ops/cross_entropy.py              | 6 +++++-
 src/liger_kernel/ops/fused_linear_cross_entropy.py | 6 ++++--
 test/transformers/test_cross_entropy.py            | 6 +++---
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py
index ea5d7f3c1..901809d4d 100644
--- a/src/liger_kernel/ops/cross_entropy.py
+++ b/src/liger_kernel/ops/cross_entropy.py
@@ -65,6 +65,8 @@ def liger_cross_entropy_kernel(
         X_ptr + y
     )  # we need to store the original value of X_y for the loss calculation

+    # Label smoothing is a general case of normal cross entropy
+    # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
     scaled_x_sum = 0.0
     eps = label_smoothing / n_cols

@@ -111,10 +113,12 @@ def liger_cross_entropy_kernel(

     # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
     # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
-    # = (1 - label_smoothing) * H(q, p) + eps * sum(softmax(x_i))
+    # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
+    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
     # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
     # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
     # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
+    # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
     if label_smoothing > 0:
         smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
         loss = loss * (1 - label_smoothing) + smooth_loss
diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py
index df210c5be..bf0e3da48 100644
--- a/src/liger_kernel/ops/fused_linear_cross_entropy.py
+++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -172,7 +172,9 @@ def fused_linear_cross_entropy_backward(

 class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, _input, weight, target, bias=None, ignore_index=-100):
+    def forward(
+        ctx, _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
+    ):
         """
         Fusing the last linear layer with cross-entropy loss
             Reference: https://github.com/mgmalek/efficient_cross_entropy
@@ -189,7 +191,7 @@ def forward(ctx, _input, weight, target, bias=None, ignore_index=-100):
             ignore_index: the index to ignore in the target
         """
         loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
-            _input, weight, target, bias, ignore_index
+            _input, weight, target, bias, ignore_index, label_smoothing
         )
         # downcast to dtype and store for backward
         ctx.save_for_backward(
diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py
index 6bcb6ce32..cdc7000e5 100644
--- a/test/transformers/test_cross_entropy.py
+++ b/test/transformers/test_cross_entropy.py
@@ -83,7 +83,7 @@ def _test_correctness_with_label_smoothing_once(
     assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


-def _test_correctness_with_parameters_once(
+def _test_correctness_with_label_smoothing_with_ignore_index_once(
     target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
 ):
     torch.manual_seed(0)
@@ -410,14 +410,14 @@ def test_correctness_with_label_smoothing_once(
     torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
     reason="Needs 16GB+ GPU memory.",
 )
-def test_correctness_with_parameters_once(
+def test_correctness_with_label_smoothing_with_ignore_index_once(
     B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
 ):
     liger_ce = LigerCrossEntropyLoss(
         ignore_index=ignore_index,
         label_smoothing=label_smoothing,
     )
-    _test_correctness_with_parameters_once(
+    _test_correctness_with_label_smoothing_with_ignore_index_once(
         liger_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
     )
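With label_smoothing now threaded through, the fused path can be exercised directly. A hypothetical minimal call matching the signature in the diff above (toy shapes, chosen for illustration):

import torch
from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)

BT, H, V = 8, 16, 32  # (batch * seq len, hidden size, vocab size)
_input = torch.randn(BT, H, device="cuda", requires_grad=True)
weight = torch.randn(V, H, device="cuda", requires_grad=True)
target = torch.randint(0, V, (BT,), device="cuda")

# Positional args: _input, weight, target, bias, ignore_index, label_smoothing
loss = LigerFusedLinearCrossEntropyFunction.apply(
    _input, weight, target, None, -100, 0.1
)
loss.backward()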
From ffd25ab915360e44a7c3b13813145de53d9376f1 Mon Sep 17 00:00:00 2001
From: Tcc0403
Date: Sat, 7 Sep 2024 04:48:45 +0800
Subject: [PATCH 09/10] Add an assertion for label_smoothing

---
 src/liger_kernel/transformers/cross_entropy.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py
index 1df0fdf7f..d79b30cb6 100644
--- a/src/liger_kernel/transformers/cross_entropy.py
+++ b/src/liger_kernel/transformers/cross_entropy.py
@@ -6,6 +6,9 @@ class LigerCrossEntropyLoss(CrossEntropyLoss):
     def __init__(self, *args, **kwargs):
         super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
+        assert (self.label_smoothing > 0) and (
+            self.label_smoothing < 1
+        ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"

     def forward(self, _input, target):
         return LigerCrossEntropyFunction.apply(

From a036c15683d9e36243d0834f8c7d96d406ff65a3 Mon Sep 17 00:00:00 2001
From: Tcc0403
Date: Sat, 7 Sep 2024 04:58:09 +0800
Subject: [PATCH 10/10] Fix the assertion for label_smoothing

---
 src/liger_kernel/transformers/cross_entropy.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py
index d79b30cb6..0adb1cc87 100644
--- a/src/liger_kernel/transformers/cross_entropy.py
+++ b/src/liger_kernel/transformers/cross_entropy.py
@@ -6,8 +6,8 @@ class LigerCrossEntropyLoss(CrossEntropyLoss):
     def __init__(self, *args, **kwargs):
         super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
-        assert (self.label_smoothing > 0) and (
-            self.label_smoothing < 1
+        assert (self.label_smoothing >= 0) and (
+            self.label_smoothing <= 1
         ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"

     def forward(self, _input, target):
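After PATCH 10, the guard accepts the closed interval [0.0, 1.0] and rejects anything outside it — the strict inequalities in PATCH 09 would have rejected the default of 0.0. An illustrative check of the final behavior:

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

LigerCrossEntropyLoss(label_smoothing=0.0)  # ok: no smoothing (the default)
LigerCrossEntropyLoss(label_smoothing=1.0)  # ok: fully uniform target
try:
    LigerCrossEntropyLoss(label_smoothing=1.5)
except AssertionError as e:
    print(e)  # label_smoothing must be between 0.0 and 1.0. Got: 1.5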