Add label smoothing for cross entropy #198

Merged · 11 commits · Sep 6, 2024
39 changes: 33 additions & 6 deletions src/liger_kernel/ops/cross_entropy.py
@@ -14,6 +14,7 @@ def liger_cross_entropy_kernel(
n_cols,
n_non_ignore,
ignore_index,
label_smoothing: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
@@ -30,6 +31,7 @@ def liger_cross_entropy_kernel(
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
BLOCK_SIZE (int): The block size for Triton operations.
"""

@@ -63,12 +65,18 @@
X_ptr + y
) # we need to store the original value of X_y for the loss calculation

scaled_x_sum = 0.0
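# eps: the probability mass the smoothed target distribution assigns to each of the n_cols classes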
eps = label_smoothing / n_cols

for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
block_max = tl.max(X_block)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new
@@ -77,12 +85,16 @@
# dx_y = (softmax(x_y) - 1) / N
# dx_i = softmax(x_i) / N, i != y
# N is the number of non ignored elements in the batch
# For label smoothing:
# dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
# = dx_i - (1 - label_smoothing) / N
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
X_block = (tl.exp(X_block - m) / d) / (n_non_ignore)
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
@@ -97,9 +109,19 @@
# So we can safely calculate log (softmax(X_y)) without overflow
loss = -(ori_X_y - m - tl.log(d))

# 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - 1) / N`
# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
#          = (1 - label_smoothing) * H(q, p) - eps * sum(log(softmax(x_i)))
# Using log(softmax(x_i)) = x_i - m - log(d), the second term becomes:
# = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
loss = loss * (1 - label_smoothing) + smooth_loss

# 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N`
X_y = tl.load(X_ptr + y)
X_y += -1 / (n_non_ignore)
X_y += -(1 - label_smoothing) / (n_non_ignore)

tl.store(loss_ptr, loss)
tl.store(X_ptr + y, X_y)
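To double-check the derivation in the comments above, here is a small hypothetical PyTorch sketch (not part of this PR) that computes the smoothed loss the same way the kernel does — (1 - label_smoothing) * H(q, p) plus the uniform term rewritten with the max m and normalizer d — and compares it against torch.nn.functional.cross_entropy:

import torch
import torch.nn.functional as F

def smoothed_ce_reference(x, y, label_smoothing=0.1):
    # x: (V,) logits for one token, y: scalar target index
    V = x.shape[0]
    eps = label_smoothing / V
    m = x.max()                 # global max, as in the kernel
    d = torch.exp(x - m).sum()  # normalizer, as in the kernel
    # hard-label part: -log(softmax(x_y)) = -(x_y - m - log(d))
    hard = -(x[y] - m - torch.log(d))
    # uniform part: -eps * sum(log(softmax(x_i))) = -eps * sum(x_i) + label_smoothing * (m + log(d))
    smooth = -eps * x.sum() + label_smoothing * (m + torch.log(d))
    return (1 - label_smoothing) * hard + smooth

x = torch.randn(10)
y = torch.tensor(3)
ref = F.cross_entropy(x.unsqueeze(0), y.unsqueeze(0), label_smoothing=0.1)
assert torch.allclose(smoothed_ce_reference(x, y), ref, atol=1e-6)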
@@ -147,7 +169,7 @@ def element_mul_kernel(
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)


def cross_entropy_forward(_input, target, ignore_index):
def cross_entropy_forward(_input, target, ignore_index, label_smoothing):
BT, V = _input.shape
n_rows = BT

@@ -175,6 +197,7 @@ def cross_entropy_forward(_input, target, ignore_index):
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
BLOCK_SIZE=BLOCK_SIZE,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
@@ -216,7 +239,7 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, _input, target, ignore_index):
def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0):
"""
The forward pass of the Liger Cross Entropy loss.

@@ -225,11 +248,14 @@ def forward(ctx, _input, target, ignore_index):
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.

Returns:
tensor: The computed loss.
"""
loss, _input = cross_entropy_forward(_input, target, ignore_index)
loss, _input = cross_entropy_forward(
_input, target, ignore_index, label_smoothing
)
# TODO: investigation
# If we don't detach the _input tensor, the memory will double
# Not sure why but seems that there will be a time both grad and value exist but in different location
@@ -254,4 +280,5 @@ def backward(ctx, grad_output):
_input,
None,
None,
None,
)
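As a sanity check on the gradient comments in the kernel (dx_i = (softmax(x_i) - eps) / N, with an extra -(1 - label_smoothing) / N on the target class), this hypothetical snippet — not part of the PR — compares the closed-form gradient against PyTorch autograd:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, V, ls = 8, 16, 0.1  # N tokens, V classes, label smoothing
eps = ls / V
x = torch.randn(N, V, requires_grad=True)
y = torch.randint(0, V, (N,))

F.cross_entropy(x, y, label_smoothing=ls).backward()

# closed-form gradient used by the kernel (mean reduction, no ignored tokens)
p = torch.softmax(x.detach(), dim=-1)
grad = (p - eps) / N
grad[torch.arange(N), y] -= (1 - ls) / N

assert torch.allclose(x.grad, grad, atol=1e-6)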
3 changes: 2 additions & 1 deletion src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -13,7 +13,7 @@


def fused_linear_cross_entropy_forward(
_input, weight, target, bias=None, ignore_index=-100
_input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
):
dtype = (
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
@@ -80,6 +80,7 @@ def fused_linear_cross_entropy_forward(
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32,
)
4 changes: 3 additions & 1 deletion src/liger_kernel/transformers/cross_entropy.py
@@ -8,4 +8,6 @@ def __init__(self, *args, **kwargs):
super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)

def forward(self, _input, target):
return LigerCrossEntropyFunction.apply(_input, target, self.ignore_index)
return LigerCrossEntropyFunction.apply(
_input, target, self.ignore_index, self.label_smoothing
)
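With this change, label smoothing is enabled through the usual nn.CrossEntropyLoss-style constructor arguments. A minimal usage sketch (shapes and values are illustrative, not taken from the PR):

import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)

logits = torch.randn(4 * 128, 32000, device="cuda", requires_grad=True)  # (B*T, V)
target = torch.randint(0, 32000, (4 * 128,), device="cuda")

loss = loss_fn(logits, target)
loss.backward()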
174 changes: 174 additions & 0 deletions test/transformers/test_cross_entropy.py
@@ -61,6 +61,61 @@ def _test_correctness_with_ignore_index_once(
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_label_smoothing_once(
target_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
):
torch.manual_seed(0)
torch_ce = CrossEntropyLoss(label_smoothing=label_smoothing)

_tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
_input = _tensor.detach().clone().requires_grad_(True)
_input2 = _tensor.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)

output = torch_ce(_input, target)
output2 = target_ce(_input2, target)

assert torch.allclose(output, output2, atol=atol, rtol=rtol)

output.backward()
output2.backward()
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_with_parameters_once(
target_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
):
torch.manual_seed(0)
torch_ce = CrossEntropyLoss(
ignore_index=ignore_index, label_smoothing=label_smoothing
)

_tensor = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
_input = _tensor.detach().clone().requires_grad_(True)
_input2 = _tensor.detach().clone().requires_grad_(True)

target = torch.randint(0, V, (B * T,), device="cuda", dtype=torch.long)

# Assign some random number of elements as ignore_index
num_elements_to_assign = torch.randint(
1, B * T // 2, (1,)
).item() # Random number of elements to set to ignore_index
indices_to_assign = torch.randperm(B * T)[
:num_elements_to_assign
] # Randomly select indices
target[indices_to_assign] = ignore_index

output = torch_ce(_input, target)
output2 = target_ce(_input2, target)

assert torch.allclose(output, output2, atol=atol, rtol=rtol)

output.backward()
output2.backward()
assert torch.allclose(_input.grad, _input2.grad, atol=atol, rtol=rtol)


def _test_correctness_not_last_layer_once(
target_ce, B, T, V, scalar, dtype, atol, rtol
):
@@ -248,6 +303,125 @@ def test_correctness_with_ignore_index(
)


@pytest.mark.parametrize(
"B, T, V, label_smoothing",
[
(2, 4096, 32000, 0.1), # llama2, mistral
(2, 4096, 32000, 0.1), # llama2, mistral
(1, 4096, 128256, 0.1), # llama3
# weird shapes
(3, 423, 32000, 0.1),
],
)
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_with_label_smoothing_once(
B, T, V, label_smoothing, scalar, dtype, atol, rtol
):
liger_ce = LigerCrossEntropyLoss(label_smoothing=label_smoothing)
_test_correctness_with_label_smoothing_once(
liger_ce, B, T, V, label_smoothing, scalar, dtype, atol, rtol
)


@pytest.mark.parametrize(
"B, T, V, ignore_index, label_smoothing",
[
(2, 4096, 32000, 1, 0.1), # llama2, mistral
(2, 4096, 32000, -100, 0.2), # llama2, mistral
(1, 4096, 128256, 2, 0.1), # llama3
# weird shapes
(3, 423, 32000, -300, 0.2),
],
)
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
pytest.param(
0.1,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
1.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
pytest.param(
10.0,
torch.bfloat16,
1e-8,
5e-2,
marks=pytest.mark.skipif(
not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
),
),
(0.1, torch.float32, 1e-8, 1e-6),
(1.0, torch.float32, 1e-8, 1e-6),
(10.0, torch.float32, 1e-8, 1e-6),
],
)
@pytest.mark.skipif(
torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
reason="Needs 16GB+ GPU memory.",
)
def test_correctness_with_parameters_once(
B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
):
liger_ce = LigerCrossEntropyLoss(
ignore_index=ignore_index,
label_smoothing=label_smoothing,
)
_test_correctness_with_parameters_once(
liger_ce, B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
)


@pytest.mark.parametrize(
"B, T, V",
[