From 43cbd4e6b250218b2008cf81504b5dc9763ac228 Mon Sep 17 00:00:00 2001
From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com>
Date: Sat, 7 Sep 2024 05:07:01 +0800
Subject: [PATCH] Add label smoothing for cross entropy (#198)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary
Aims to solve #81.

## Details

### For loss:
Label smoothing regularization (LSR) replaces the hard label distribution $q(k) = \delta_{k,y}$ with

```math
q'(k) = (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}.
```

For example, with $K = 4$ classes and $\epsilon = 0.1$, the target class receives probability $0.9 + 0.1/4 = 0.925$ and every other class receives $0.025$.

Cross entropy with LSR is then

```math
\begin{align}
L' = H(q', p) &= -\sum^K_{k=1} q'(k)\,\log p(k) = -\sum^K_{k=1}\Big((1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}\Big)\log p(k)\\
&= -(1 - \epsilon)\sum^K_{k=1} q(k)\,\log p(k) - \frac{\epsilon}{K}\sum^K_{k=1}\log p(k) \\
&= (1 - \epsilon)\,H(q,p) - \frac{\epsilon}{K}\sum^K_{k=1}\log \mathrm{softmax}(x_k)\\
&= (1 - \epsilon)\,L + \frac{\epsilon}{K}\,\mathrm{SmoothLoss},
\end{align}
```

where $L = H(q,p)$ is the original loss and $\mathrm{SmoothLoss} = -\sum^K_{k=1}\log \mathrm{softmax}(x_k)$ is the smooth loss.

### For gradients:
The original:

```math
\begin{align}
\frac{\partial L}{\partial x_i} &= p(i) - q(i)\\
&= \begin{cases} \mathrm{softmax}(x_i), & i \neq y \\ \mathrm{softmax}(x_i) - 1, & i = y \end{cases}
\end{align}
```

With LSR:

```math
\begin{align}
\frac{\partial L'}{\partial x_i} &= p(i) - q'(i)\\
&= \mathrm{softmax}(x_i) - (1 - \epsilon)\delta_{i,y} - \frac{\epsilon}{K}\\
&= \begin{cases} \mathrm{softmax}(x_i) - \frac{\epsilon}{K}, & i \neq y \\ \mathrm{softmax}(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}
\end{align}
```

We can handle the $i = y$ case by simply adding $-(1-\epsilon)$ to that entry after computing the common term for all $i$. (A small PyTorch sketch that numerically checks both the loss decomposition and this gradient is included after the first block of test output below.)

Reference: [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567)

## Testing Done
Add a unit test for label smoothing.

- Hardware Type: RTX-3080
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

The single failure in the first run below is `test_large_no_exception[8-16384-128256]`, which runs out of memory while allocating its input tensor on the 10 GB RTX-3080; that case is sized for a much larger GPU.

```bash
❯ python3 -m pytest test/transformers/test_cross_entropy.py
============================================ test session starts =============================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
collected 94 items

test/transformers/test_cross_entropy.py ..............................................................
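As a quick numerical check of the derivation in the Details section above, here is a small sketch that is not part of the patch. It uses only plain PyTorch; the batch size, class count, and ε value are arbitrary. It verifies the loss decomposition, the max/log-sum-exp form the kernel accumulates, and the gradient formula:

```python
# Sanity-check sketch for the label-smoothing derivation above.
# Not part of the patch; uses only PyTorch, with arbitrary B, K, and eps.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, K, eps = 8, 11, 0.1                      # batch size, num classes, label_smoothing
x = torch.randn(B, K, requires_grad=True)   # logits
y = torch.randint(0, K, (B,))               # target class indices

# Reference: PyTorch's label-smoothed cross entropy (mean over the batch).
loss = F.cross_entropy(x, y, label_smoothing=eps)

# Decomposition from the derivation: L' = (1 - eps) * L + (eps / K) * SmoothLoss,
# with per-sample SmoothLoss = -sum_k log softmax(x)_k.
log_p = F.log_softmax(x, dim=-1)
hard_loss = F.nll_loss(log_p, y)              # L = H(q, p), mean over the batch
smooth_loss = (-log_p.sum(dim=-1)).mean()     # mean per-sample SmoothLoss
assert torch.allclose(loss, (1 - eps) * hard_loss + (eps / K) * smooth_loss, atol=1e-6)

# The kernel accumulates the same smooth term online as
# scaled_x_sum + label_smoothing * (m + log d), with m = row max, d = sum(exp(x - m)).
m = x.max(dim=-1).values
d = torch.exp(x - m.unsqueeze(-1)).sum(dim=-1)
kernel_form = (-(eps / K) * x).sum(dim=-1) + eps * (m + d.log())
assert torch.allclose(kernel_form, (eps / K) * (-log_p.sum(dim=-1)), atol=1e-5)

# Gradient: d loss / d x_i = (softmax(x_i) - eps / K - (1 - eps) * [i == y]) / B
loss.backward()
expected_grad = (F.softmax(x, dim=-1) - eps / K - (1 - eps) * F.one_hot(y, K)) / B
assert torch.allclose(x.grad, expected_grad, atol=1e-6)
print("loss decomposition, kernel form, and gradient all match")
```

With the patch applied, the same smoothing is exposed through the module updated in `src/liger_kernel/transformers/cross_entropy.py`, e.g. `LigerCrossEntropyLoss(label_smoothing=0.1)` (a CUDA device is required for the Triton kernel).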
[ 65%] ...............................F [100%] ================================================== FAILURES ================================================== __________________________________ test_large_no_exception[8-16384-128256] ___________________________________ B = 8, T = 16384, V = 128256 @pytest.mark.parametrize( "B, T, V", [ ( 8, 8192, 128256, ), # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64 (8, 16384, 128256), # _input = 32GB, total = ~64GB ], ) # @pytest.mark.skipif( # torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000, # reason="Needs 64GB+ GPU memory.", # ) def test_large_no_exception(B, T, V): # The large inputs were hitting cuda illegal memory access because of # https://github.com/triton-lang/triton/issues/1058 > _full_pass_once(B, T, V) test/transformers/test_cross_entropy.py:401: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ B = 8, T = 16384, V = 128256 def _full_pass_once(B, T, V): torch.manual_seed(0) liger_ce = LigerCrossEntropyLoss() > _input = torch.randn( B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16 ) E torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10.00 GiB of which 8.84 GiB is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 0 bytes is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) test/transformers/test_cross_entropy.py:374: OutOfMemoryError ========================================== short test summary info =========================================== FAILED test/transformers/test_cross_entropy.py::test_large_no_exception[8-16384-128256] - torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10... ================================== 1 failed, 93 passed in 130.88s (0:02:10) ================================== ``` ```bash ❯ make test python -m pytest --disable-warnings test/ --ignore=test/convergence ============================================ test session starts ============================================= platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0 rootdir: /home/tcc/Liger-Kernel collected 256 items test/transformers/test_auto_model.py . [ 0%] test/transformers/test_cross_entropy.py ssssssssssssssssssssssss............ssssssssssssssssssssssssss [ 24%] ssssssssssssssssssssssssssssssss [ 37%] test/transformers/test_embedding.py ........... [ 41%] test/transformers/test_fused_linear_cross_entropy.py ................ [ 47%] test/transformers/test_geglu.py ............ [ 52%] test/transformers/test_layer_norm.py ................ [ 58%] test/transformers/test_monkey_patch.py ..... [ 60%] test/transformers/test_rms_norm.py ............................................................ [ 83%] test/transformers/test_rope.py .................. [ 91%] test/transformers/test_swiglu.py .................... [ 98%] test/transformers/test_trainer_integration.py . [ 99%] test/triton/test_triton_monkey_patch.py .. 
[100%] ================================ 174 passed, 82 skipped in 123.06s (0:02:03) ================================= ``` ```bash ❯ make checkstyle flake8 .; flake8_status=$?; \ isort .; isort_status=$?; \ black .; black_status=$?; \ if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \ exit 1; \ fi Skipped 2 files All done! ✨ 🍰 ✨ 68 files left unchanged. ``` ```bash ❯ make test-convergence HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence ============================================ test session starts ============================================= platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0 rootdir: /home/tcc/Liger-Kernel collected 30 items test/convergence/test_mini_models.py .............. [ 46%] test/convergence/test_mini_models_no_logits.py ................ [100%] ======================================= 30 passed in 223.18s (0:03:43) ======================================= ``` --- src/liger_kernel/ops/cross_entropy.py | 43 ++++- .../ops/fused_linear_cross_entropy.py | 9 +- .../transformers/cross_entropy.py | 7 +- test/transformers/test_cross_entropy.py | 174 ++++++++++++++++++ 4 files changed, 223 insertions(+), 10 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 32703788c..901809d4d 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -14,6 +14,7 @@ def liger_cross_entropy_kernel( n_cols, n_non_ignore, ignore_index, + label_smoothing: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ @@ -30,6 +31,7 @@ def liger_cross_entropy_kernel( n_cols (int): The number of columns in the input tensor. n_non_ignore (int): The number of non-ignored elements in the batch. ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. BLOCK_SIZE (int): The block size for Triton operations. 
""" @@ -63,12 +65,20 @@ def liger_cross_entropy_kernel( X_ptr + y ) # we need to store the original value of X_y for the loss calculation + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load( X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf") ) block_max = tl.max(X_block) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) m_new = tl.maximum(m, block_max) d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) m = m_new @@ -77,12 +87,16 @@ def liger_cross_entropy_kernel( # dx_y = (softmax(x_y) - 1) / N # dx_i = softmax(x_i) / N, i != y # N is the number of non ignored elements in the batch + # For label smoothing: + # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N for i in range(0, n_cols, BLOCK_SIZE): X_offsets = i + tl.arange(0, BLOCK_SIZE) X_block = tl.load( X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf") ) - X_block = (tl.exp(X_block - m) / d) / (n_non_ignore) + X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in @@ -97,9 +111,21 @@ def liger_cross_entropy_kernel( # So we can safely calculate log (softmax(X_y)) without overflow loss = -(ori_X_y - m - tl.log(d)) - # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - 1) / N` + # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 + # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 + if label_smoothing > 0: + smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) + loss = loss * (1 - label_smoothing) + smooth_loss + + # 6. 
Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` X_y = tl.load(X_ptr + y) - X_y += -1 / (n_non_ignore) + X_y += -(1 - label_smoothing) / (n_non_ignore) tl.store(loss_ptr, loss) tl.store(X_ptr + y, X_y) @@ -147,7 +173,7 @@ def element_mul_kernel( tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) -def cross_entropy_forward(_input, target, ignore_index): +def cross_entropy_forward(_input, target, ignore_index, label_smoothing): BT, V = _input.shape n_rows = BT @@ -175,6 +201,7 @@ def cross_entropy_forward(_input, target, ignore_index): n_cols=V, n_non_ignore=n_non_ignore, ignore_index=ignore_index, + label_smoothing=label_smoothing, BLOCK_SIZE=BLOCK_SIZE, # TODO: 32 seems to give the best performance # Performance is quite sensitive to num_warps @@ -216,7 +243,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, _input, target, ignore_index): + def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0): """ The forward pass of the Liger Cross Entropy loss. @@ -225,11 +252,14 @@ def forward(ctx, _input, target, ignore_index): _input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size. target (tensor): The target tensor of shape (BT) where each value is in [0, V-1]. ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. Returns: tensor: The computed loss. """ - loss, _input = cross_entropy_forward(_input, target, ignore_index) + loss, _input = cross_entropy_forward( + _input, target, ignore_index, label_smoothing + ) # TODO: investigation # If we don't detach the _input tensor, the memory will double # Not sure why but seems that there will be a time both grad and value exist but in different location @@ -254,4 +284,5 @@ def backward(ctx, grad_output): _input, None, None, + None, ) diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 7b62dbbb1..bf0e3da48 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -13,7 +13,7 @@ def fused_linear_cross_entropy_forward( - _input, weight, target, bias=None, ignore_index=-100 + _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0 ): dtype = ( torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype @@ -80,6 +80,7 @@ def fused_linear_cross_entropy_forward( n_cols=V, n_non_ignore=n_non_ignore, ignore_index=ignore_index, + label_smoothing=label_smoothing, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, ) @@ -171,7 +172,9 @@ def fused_linear_cross_entropy_backward( class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): @staticmethod - def forward(ctx, _input, weight, target, bias=None, ignore_index=-100): + def forward( + ctx, _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0 + ): """ Fusing the last linear layer with cross-entropy loss Reference: https://github.com/mgmalek/efficient_cross_entropy @@ -188,7 +191,7 @@ def forward(ctx, _input, weight, target, bias=None, ignore_index=-100): ignore_index: the index to ignore in the target """ loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( - _input, weight, target, bias, ignore_index + _input, weight, target, bias, ignore_index, label_smoothing ) # downcast to dtype and store for backward 
ctx.save_for_backward( diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index 44255eb85..0adb1cc87 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -6,6 +6,11 @@ class LigerCrossEntropyLoss(CrossEntropyLoss): def __init__(self, *args, **kwargs): super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs) + assert (self.label_smoothing >= 0) and ( + self.label_smoothing <= 1 + ), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}" def forward(self, _input, target): - return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index) + return LigerCrossEntropyFunction.apply( + _input, target, self.ignore_index, self.label_smoothing + ) diff --git a/test/transformers/test_cross_entropy.py b/test/transformers/test_cross_entropy.py index c1fcccb34..cdc7000e5 100644 --- a/test/transformers/test_cross_entropy.py +++ b/test/transformers/test_cross_entropy.py @@ -61,6 +61,61 @@ def _test_correctness_with_ignore_index_once( assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) +def _test_correctness_with_label_smoothing_once( + target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol +): + torch.manual_seed(0) + torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing) + + _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + + output = torch_ce(_input, target) + output2 = target_ce(_input2, target) + + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + +def _test_correctness_with_label_smoothing_with_ignore_index_once( + target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol +): + torch.manual_seed(0) + torch_ce = CrossEntropyLoss( + ignore_index=ignore_index, label_smoothing=label_smoothing + ) + + _tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar + _input = _tensor.detach().clone().requires_grad_(True) + _input2 = _tensor.detach().clone().requires_grad_(True) + + target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long) + + # Assign some random number of elements as ignore_index + num_elements_to_assign = torch.randint( + 1, B * T // 2, (1,) + ).item() # Random number of elements to set to ignore_index + indices_to_assign = torch.randperm(B * T)[ + :num_elements_to_assign + ] # Randomly select indices + target[indices_to_assign] = ignore_index + + output = torch_ce(_input, target) + output2 = target_ce(_input2, target) + + assert torch.allclose(output, output2, atol=atol, rtol=rtol) + + output.backward() + output2.backward() + assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol) + + def _test_correctness_not_last_layer_once( target_ce, B, T, V, scalar, dtype, atol, rtol ): @@ -248,6 +303,125 @@ def test_correctness_with_ignore_index( ) +@pytest.mark.parametrize( + "B, T, V, label_smoothing", + [ + (2, 4096, 32000, 0.1), # llama2, mistral + (2, 4096, 32000, 0.1), # llama2, mistral + (1, 4096, 128256, 0.1), # llama3 + # weird shapes + (3, 423, 32000, 0.1), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 0.1, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not 
supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + pytest.param( + 10.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (0.1, torch.float32, 1e-8, 1e-6), + (1.0, torch.float32, 1e-8, 1e-6), + (10.0, torch.float32, 1e-8, 1e-6), + ], +) +@pytest.mark.skipif( + torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, + reason="Needs 16GB+ GPU memory.", +) +def test_correctness_with_label_smoothing_once( + B, T, V, label_smoothing, scalar, dtype, atol, rtol +): + liger_ce = LigerCrossEntropyLoss(label_smoothing=label_smoothing) + _test_correctness_with_label_smoothing_once( + liger_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol + ) + + +@pytest.mark.parametrize( + "B, T, V, ignore_index, label_smoothing", + [ + (2, 4096, 32000, 1, 0.1), # llama2, mistral + (2, 4096, 32000, -100, 0.2), # llama2, mistral + (1, 4096, 128256, 2, 0.1), # llama3 + # weird shapes + (3, 423, 32000, -300, 0.2), + ], +) +@pytest.mark.parametrize( + "scalar, dtype, atol, rtol", + [ + pytest.param( + 0.1, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + pytest.param( + 1.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + pytest.param( + 10.0, + torch.bfloat16, + 1e-8, + 5e-2, + marks=pytest.mark.skipif( + not supports_bfloat16(), reason="bfloat16 not supported on this GPU" + ), + ), + (0.1, torch.float32, 1e-8, 1e-6), + (1.0, torch.float32, 1e-8, 1e-6), + (10.0, torch.float32, 1e-8, 1e-6), + ], +) +@pytest.mark.skipif( + torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000, + reason="Needs 16GB+ GPU memory.", +) +def test_correctness_with_label_smoothing_with_ignore_index_once( + B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol +): + liger_ce = LigerCrossEntropyLoss( + ignore_index=ignore_index, + label_smoothing=label_smoothing, + ) + _test_correctness_with_label_smoothing_with_ignore_index_once( + liger_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol + ) + + @pytest.mark.parametrize( "B, T, V", [