Add label smoothing for cross entropy (linkedin#198)
## Summary
Aims to solve linkedin#81.

## Details

### For loss:
Label smoothing regularization (LSR) replaces the label distribution $q(k) = \delta_{k,y}$ with
```math
q'(k) = (1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}
```
The cross entropy with LSR is then

```math
\begin{align}
L' = H(q', p) &= -\sum^K_{k=1} q'(k)\log p(k) = -\sum^K_{k=1} \left((1 - \epsilon)\delta_{k,y} + \frac{\epsilon}{K}\right)\log p(k)\\
              &= -(1 - \epsilon)\sum^K_{k=1} q(k)\log p(k) - \frac{\epsilon}{K}\sum^K_{k=1} \log p(k) \\
              &= (1 - \epsilon)H(q,p) - \frac{\epsilon}{K} \sum^K_{k=1} \log \mathrm{softmax}(x_k)\\
              &= (1 - \epsilon)L + \frac{\epsilon}{K}\ \mathrm{SmoothLoss},
\end{align}
```
where $L = H(q,p)$ is the original loss and $\mathrm{SmoothLoss} = -\sum^K_{k=1} \log \mathrm{softmax}(x_k)$ is the smoothing term.
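
The smoothing term can be accumulated in the same online-softmax pass that already computes the running max $m = \max_k x_k$ and normalizer $d = \sum_k e^{x_k - m}$, which is how the Triton kernel below builds `smooth_loss`:

```math
\mathrm{SmoothLoss} = -\sum^K_{k=1} \log \mathrm{softmax}(x_k) = \sum^K_{k=1} (m + \log d - x_k) = K(m + \log d) - \sum^K_{k=1} x_k,
```
so only the running sum $\sum_k x_k$ (scaled by $\epsilon/K$) and $\epsilon\,(m + \log d)$ at the end are needed.

A quick numerical sanity check of the loss identity against PyTorch's built-in `label_smoothing` (illustrative only, not part of this PR; the shapes and $\epsilon$ are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, K, eps = 4, 10, 0.1
x = torch.randn(B, K)                          # logits
y = torch.randint(0, K, (B,))                  # targets

log_p = F.log_softmax(x, dim=-1)
ce = -log_p[torch.arange(B), y]                # L = H(q, p), per sample
smooth = -log_p.sum(dim=-1)                    # SmoothLoss = -sum_k log softmax(x_k)
lsr = (1 - eps) * ce + (eps / K) * smooth      # L' = (1 - eps) L + (eps / K) SmoothLoss

ref = F.cross_entropy(x, y, label_smoothing=eps, reduction="none")
print(torch.allclose(lsr, ref, atol=1e-6))     # expected: True
```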

### For gradients:
The original gradient:
```math
\begin{align}
\frac{\partial L}{\partial x_i} &= p(i) - q(i)\\
                                &= \begin{cases}
                                       \mathrm{softmax}(x_i),     & i \neq y \\
                                       \mathrm{softmax}(x_i) - 1, & i = y
                                   \end{cases}
\end{align}
```
With LSR:
```math
\begin{align}
\frac{\partial L'}{\partial x_i} &= p(i) - q'(i)\\
                                 &= \mathrm{softmax}(x_i) - (1 - \epsilon)\delta_{i,y} - \frac{\epsilon}{K}\\
                                 &= \begin{cases}
                                        \mathrm{softmax}(x_i) - \frac{\epsilon}{K},                  & i \neq y \\
                                        \mathrm{softmax}(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y
                                    \end{cases}
\end{align}
```

We can handle the $i = y$ case by simply adding the extra $-(1 - \epsilon)$ term after computing the $i \neq y$ expression for all $i$.
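
A small autograd check of this gradient rule (again illustrative, not part of the PR; a single unbatched row so the $1/N$ factor is $1$):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, eps, tgt = 10, 0.1, 3
x = torch.randn(1, K, requires_grad=True)              # one row of logits
y = torch.tensor([tgt])

F.cross_entropy(x, y, label_smoothing=eps).backward()

manual = torch.softmax(x.detach(), dim=-1) - eps / K   # the i != y expression, applied to every i
manual[0, tgt] -= 1 - eps                              # extra -(1 - eps) only at i == y
print(torch.allclose(x.grad, manual, atol=1e-6))       # expected: True
```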


Reference:
[Rethinking the Inception Architecture for Computer
Vision](https://arxiv.org/abs/1512.00567)
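
For completeness, a minimal usage sketch of the new argument (the import path is inferred from `src/liger_kernel/transformers/cross_entropy.py` in the diff below; the default `label_smoothing=0.0` keeps the current behavior):

```python
import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

# label_smoothing is taken from nn.CrossEntropyLoss's constructor and forwarded to the Triton kernel
loss_fn = LigerCrossEntropyLoss(label_smoothing=0.1)

logits = torch.randn(8, 32000, device="cuda", requires_grad=True)  # (BT, V)
target = torch.randint(0, 32000, (8,), device="cuda")              # (BT,)

loss = loss_fn(logits, target)
loss.backward()
```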

## Testing Done
Added unit tests for label smoothing, both with and without `ignore_index`. The only failure in the standalone pytest run below is the pre-existing large-input test, which runs out of memory on a 10 GB RTX 3080 and is unrelated to this change.

- Hardware Type: RTX-3080
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
```bash
❯ python3 -m pytest test/transformers/test_cross_entropy.py
============================================ test session starts =============================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
collected 94 items

test/transformers/test_cross_entropy.py .............................................................. [ 65%]
...............................F                                                                       [100%]

================================================== FAILURES ==================================================
__________________________________ test_large_no_exception[8-16384-128256] ___________________________________

B = 8, T = 16384, V = 128256

    @pytest.mark.parametrize(
        "B, T, V",
        [
            (
                8,
                8192,
                128256,
            ),  # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64
            (8, 16384, 128256),  # _input = 32GB, total = ~64GB
        ],
    )
    # @pytest.mark.skipif(
    #     torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000,
    #     reason="Needs 64GB+ GPU memory.",
    # )
    def test_large_no_exception(B, T, V):
        # The large inputs were hitting cuda illegal memory access because of
        # triton-lang/triton#1058
>       _full_pass_once(B, T, V)

test/transformers/test_cross_entropy.py:401:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

B = 8, T = 16384, V = 128256

    def _full_pass_once(B, T, V):
        torch.manual_seed(0)
        liger_ce = LigerCrossEntropyLoss()

>       _input = torch.randn(
            B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16
        )
E       torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10.00 GiB of which 8.84 GiB is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 0 bytes is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

test/transformers/test_cross_entropy.py:374: OutOfMemoryError
========================================== short test summary info ===========================================
FAILED test/transformers/test_cross_entropy.py::test_large_no_exception[8-16384-128256] - torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 31.31 GiB. GPU 0 has a total capacity of 10...
================================== 1 failed, 93 passed in 130.88s (0:02:10) ==================================
```
```bash
❯ make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
============================================ test session starts =============================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
collected 256 items

test/transformers/test_auto_model.py .                                                                 [  0%]
test/transformers/test_cross_entropy.py ssssssssssssssssssssssss............ssssssssssssssssssssssssss [ 24%]
ssssssssssssssssssssssssssssssss                                                                       [ 37%]
test/transformers/test_embedding.py ...........                                                        [ 41%]
test/transformers/test_fused_linear_cross_entropy.py ................                                  [ 47%]
test/transformers/test_geglu.py ............                                                           [ 52%]
test/transformers/test_layer_norm.py ................                                                  [ 58%]
test/transformers/test_monkey_patch.py .....                                                           [ 60%]
test/transformers/test_rms_norm.py ............................................................        [ 83%]
test/transformers/test_rope.py ..................                                                      [ 91%]
test/transformers/test_swiglu.py ....................                                                  [ 98%]
test/transformers/test_trainer_integration.py .                                                        [ 99%]
test/triton/test_triton_monkey_patch.py ..                                                             [100%]

================================ 174 passed, 82 skipped in 123.06s (0:02:03) =================================
```
```bash
❯ make checkstyle
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
Skipped 2 files
All done! ✨ 🍰 ✨
68 files left unchanged.
```
```bash
❯ make test-convergence
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
============================================ test session starts =============================================
platform linux -- Python 3.10.12, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
collected 30 items

test/convergence/test_mini_models.py ..............                                                    [ 46%]
test/convergence/test_mini_models_no_logits.py ................                                        [100%]

======================================= 30 passed in 223.18s (0:03:43) =======================================
```
Tcc0403 authored Sep 6, 2024
1 parent 376fe0c commit 43cbd4e
Showing 4 changed files with 223 additions and 10 deletions.
43 changes: 37 additions & 6 deletions src/liger_kernel/ops/cross_entropy.py
@@ -14,6 +14,7 @@ def liger_cross_entropy_kernel(
n_cols,
n_non_ignore,
ignore_index,
label_smoothing: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
@@ -30,6 +31,7 @@ def liger_cross_entropy_kernel(
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
BLOCK_SIZE (int): The block size for Triton operations.
"""

@@ -63,12 +65,20 @@ def liger_cross_entropy_kernel(
X_ptr + y
) # we need to store the original value of X_y for the loss calculation

# Label smoothing is a general case of normal cross entropy
# See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310
scaled_x_sum = 0.0
eps = label_smoothing / n_cols

for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
block_max = tl.max(X_block)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new
@@ -77,12 +87,16 @@ def liger_cross_entropy_kernel(
# dx_y = (softmax(x_y) - 1) / N
# dx_i = softmax(x_i) / N, i != y
# N is the number of non ignored elements in the batch
# For label smoothing:
# dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
# = dx_i - (1 - label_smoothing) / N
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
X_block = (tl.exp(X_block - m) / d) / (n_non_ignore)
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
@@ -97,9 +111,21 @@ def liger_cross_entropy_kernel(
# So we can safely calculate log (softmax(X_y)) without overflow
loss = -(ori_X_y - m - tl.log(d))

# 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - 1) / N`
# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
# = (1 - label_smoothing) * H(q, p) - eps * sum(logsoftmax(x_i))
# By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
# = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
loss = loss * (1 - label_smoothing) + smooth_loss

# 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N`
X_y = tl.load(X_ptr + y)
X_y += -1 / (n_non_ignore)
X_y += -(1 - label_smoothing) / (n_non_ignore)

tl.store(loss_ptr, loss)
tl.store(X_ptr + y, X_y)
@@ -147,7 +173,7 @@ def element_mul_kernel(
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)


def cross_entropy_forward(_input, target, ignore_index):
def cross_entropy_forward(_input, target, ignore_index, label_smoothing):
BT, V = _input.shape
n_rows = BT

@@ -175,6 +201,7 @@ def cross_entropy_forward(_input, target, ignore_index):
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
BLOCK_SIZE=BLOCK_SIZE,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
@@ -216,7 +243,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, _input, target, ignore_index):
def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0):
"""
The forward pass of the Liger Cross Entropy loss.
@@ -225,11 +252,14 @@ def forward(ctx, _input, target, ignore_index):
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
Returns:
tensor: The computed loss.
"""
loss, _input = cross_entropy_forward(_input, target, ignore_index)
loss, _input = cross_entropy_forward(
_input, target, ignore_index, label_smoothing
)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
@@ -254,4 +284,5 @@ def backward(ctx, grad_output):
_input,
None,
None,
None,
)
9 changes: 6 additions & 3 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -13,7 +13,7 @@


def fused_linear_cross_entropy_forward(
_input, weight, target, bias=None, ignore_index=-100
_input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
):
dtype = (
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
@@ -80,6 +80,7 @@ def fused_linear_cross_entropy_forward(
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,
)
@@ -171,7 +172,9 @@ def fused_linear_cross_entropy_backward(

class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, _input, weight, target, bias=None, ignore_index=-100):
def forward(
ctx, _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
):
"""
Fusing the last linear layer with cross-entropy loss
Reference: https://github.com/mgmalek/efficient_cross_entropy
@@ -188,7 +191,7 @@ def forward(ctx, _input, weight, target, bias=None, ignore_index=-100):
ignore_index: the index to ignore in the target
"""
loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
_input, weight, target, bias, ignore_index
_input, weight, target, bias, ignore_index, label_smoothing
)
# downcast to dtype and store for backward
ctx.save_for_backward(
7 changes: 6 additions & 1 deletion src/liger_kernel/transformers/cross_entropy.py
@@ -6,6 +6,11 @@
class LigerCrossEntropyLoss(CrossEntropyLoss):
def __init__(self, *args, **kwargs):
super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
assert (self.label_smoothing >= 0) and (
self.label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"

def forward(self, _input, target):
return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index)
return LigerCrossEntropyFunction.apply(
_input, target, self.ignore_index, self.label_smoothing
)
174 changes: 174 additions & 0 deletions test/transformers/test_cross_entropy.py
@@ -61,6 +61,61 @@ def _test_correctness_with_ignore_index_once(
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_label_smoothing_once(
target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
):
torch.manual_seed(0)
torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing)

_tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
_input = _tensor.detach().clone().requires_grad_(True)
_input2 = _tensor.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)

output = torch_ce(_input, target)
output2 = target_ce(_input2, target)

assert torch.allclose(output, output2, atol=atol, rtol=rtol)

output.backward()
output2.backward()
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_label_smoothing_with_ignore_index_once(
target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
):
torch.manual_seed(0)
torch_ce = CrossEntropyLoss(
ignore_index=ignore_index, label_smoothing=label_smoothing
)

_tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
_input = _tensor.detach().clone().requires_grad_(True)
_input2 = _tensor.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)

# Assign some random number of elements as ignore_index
num_elements_to_assign = torch.randint(
1, B * T // 2, (1,)
).item() # Random number of elements to set to ignore_index
indices_to_assign = torch.randperm(B * T)[
:num_elements_to_assign
] # Randomly select indices
target[indices_to_assign] = ignore_index

output = torch_ce(_input, target)
output2 = target_ce(_input2, target)

assert torch.allclose(output, output2, atol=atol, rtol=rtol)

output.backward()
output2.backward()
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_not_last_layer_once(
target_ce, B, T, V, scalar, dtype, atol, rtol
):
@@ -248,6 +303,125 @@ def test_correctness_with_ignore_index(
)


@pytest.mark.parametrize(
"B, T, V, label_smoothing",
[
(2, 4096, 32000, 0.1), # llama2, mistral
(2, 4096, 32000, 0.1), # llama2, mistral
(1, 4096, 128256, 0.1), # llama3
# weird shapes
(3, 423, 32000, 0.1),
],
)
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_with_label_smoothing_once(
B, T, V, label_smoothing, scalar, dtype, atol, rtol
):
liger_ce = LigerCrossEntropyLoss(label_smoothing=label_smoothing)
_test_correctness_with_label_smoothing_once(
liger_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
)


@pytest.mark.parametrize(
"B, T, V, ignore_index, label_smoothing",
[
(2, 4096, 32000, 1, 0.1), # llama2, mistral
(2, 4096, 32000, -100, 0.2), # llama2, mistral
(1, 4096, 128256, 2, 0.1), # llama3
# weird shapes
(3, 423, 32000, -300, 0.2),
],
)
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_with_label_smoothing_with_ignore_index_once(
B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
):
liger_ce = LigerCrossEntropyLoss(
ignore_index=ignore_index,
label_smoothing=label_smoothing,
)
_test_correctness_with_label_smoothing_with_ignore_index_once(
liger_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
)


@pytest.mark.parametrize(
"B, T, V",
[
