Add weight support for LigerCrossEntropy #420

Status: Open · wants to merge 8 commits into base: main
58 changes: 53 additions & 5 deletions src/liger_kernel/ops/cross_entropy.py
@@ -27,18 +27,21 @@ def liger_cross_entropy_kernel(
X_stride,
Y_ptr,
Y_stride,
weight_ptr,
loss_ptr,
z_loss_ptr,
loss_stride,
n_cols,
n_non_ignore,
weight_sum,
ignore_index,
lse_square_scale: tl.constexpr,
label_smoothing: tl.constexpr,
reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time
softcap,
RETURN_Z_LOSS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_SOFTCAPPING: tl.constexpr,
):
"""
@@ -50,18 +53,22 @@
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
        weight_ptr: Pointer to the class-weight tensor (a dummy pointer when HAS_WEIGHT is False).
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
        n_non_ignore (float): The number of non-ignored elements in the batch, or the sum of the non-ignored targets' class weights when HAS_WEIGHT is set.
        weight_sum (float): The sum of the weight tensor.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        lse_square_scale (float): The scale factor for (logsumexp(_input)) ^ 2, added to the loss for training stability.
        RETURN_Z_LOSS (int): Whether to store the z loss to z_loss_ptr. Must be 0 or 1.
        reduction (str): The reduction to apply.
softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap).
BLOCK_SIZE (int): The block size for Triton operations.
        HAS_WEIGHT (bool): Whether a class-weight tensor was provided.
        HAS_SOFTCAPPING (bool): Whether soft-capping is applied.
"""

@@ -86,6 +93,9 @@
loss_ptr += program_id * loss_stride
z_loss_ptr += program_id * loss_stride

    if HAS_WEIGHT:
        # class weight of this row's target label, promoted to fp32 for accumulation
        weight_y = tl.load(weight_ptr + y).cast(tl.float32)

# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867

@@ -116,7 +126,15 @@
block_max = tl.max(X_block)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
if HAS_WEIGHT:
weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
scaled_x_sum += tl.sum(
tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)
)
else:
scaled_x_sum += tl.sum(
tl.where(X_offsets < n_cols, -eps * X_block, 0.0)
)
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new
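Review note (not part of the diff): with eps = label_smoothing / V, the weighted accumulation above is what makes the smooth-loss branch further down work out, since

$$\textstyle\sum_j \varepsilon\, w_j (\mathrm{lse} - x_j) = \varepsilon \cdot \mathrm{lse} \cdot \sum_j w_j + \sum_j (-\varepsilon\, w_j x_j) = \varepsilon \cdot \mathrm{lse} \cdot \texttt{weight\_sum} + \texttt{scaled\_x\_sum},$$

which is exactly the `smooth_loss = scaled_x_sum + eps * lse * weight_sum` line in the label-smoothing branch below.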
@@ -163,6 +181,8 @@
# reduction scale
if reduction == "mean":
X_block = X_block / (n_non_ignore)
if HAS_WEIGHT:
X_block = X_block * weight_y
# chain rule
# d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
if HAS_SOFTCAPPING:
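Review note: softcap, z-loss, and label smoothing aside, the two scaling lines above produce the weighted mean-reduction gradient

$$\frac{\partial \mathcal{L}}{\partial x_{i,c}} = \frac{w_{y_i}\,(\mathrm{softmax}(x_i)_c - \mathbf{1}[c = y_i])}{\sum_{j\,:\,y_j \neq \mathrm{ignore\_index}} w_{y_j}},$$

with `n_non_ignore` holding the weighted denominator when HAS_WEIGHT is set (see `cross_entropy_forward` below).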
@@ -182,6 +202,8 @@
# sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
# So we can safely calculate log (softmax(X_y)) without overflow
loss = lse - ori_X_y
if HAS_WEIGHT:
loss = weight_y * loss

# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -192,7 +214,10 @@
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * lse
if HAS_WEIGHT:
smooth_loss = scaled_x_sum + eps * lse * weight_sum
else:
smooth_loss = scaled_x_sum + label_smoothing * lse
loss = loss * (1 - label_smoothing) + smooth_loss

# An auxiliary loss, z_loss
@@ -203,7 +228,6 @@
if reduction == "mean":
z_loss = z_loss / n_non_ignore
loss = loss / n_non_ignore

tl.store(loss_ptr, loss)
if RETURN_Z_LOSS == _TRUE:
tl.store(z_loss_ptr, z_loss)
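Review note: putting the pieces together (label smoothing and z-loss aside), each row's stored loss and the mean reduction are, in effect,

$$\ell_i = w_{y_i}\,(\mathrm{lse}(X_i) - X_{i,y_i}), \qquad \mathcal{L}_{\mathrm{mean}} = \frac{\sum_{i\,:\,y_i \neq \mathrm{ignore\_index}} \ell_i}{\sum_{i\,:\,y_i \neq \mathrm{ignore\_index}} w_{y_i}},$$

which matches the weighted-mean convention of torch.nn.CrossEntropyLoss.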
@@ -224,6 +248,7 @@ def liger_cross_entropy_kernel(
def cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
@@ -253,7 +278,23 @@ def cross_entropy_forward(
else:
z_loss_1d = loss_1d # dummy ptr when return_z_loss == False

n_non_ignore = (target != ignore_index).sum().item()
target_mask = target != ignore_index
n_non_ignore = target_mask.sum().item()
    weight_sum = weight.sum().item() if weight is not None else 0.0
if weight is not None:
assert (
weight.shape[0] == V
), f"If given, weight has to be a Tensor of size V. Got: {weight.shape}"
assert torch.is_floating_point(
weight
), f"If given, weight has to be a Tensor of floating point dtype. Got: {weight.dtype}"
n_non_ignore = (
torch.gather(weight, dim=0, index=target.masked_select(target_mask))
.sum()
.item()
)
if weight.stride(-1) != 1:
weight = weight.contiguous()

# ensure _input and target are contiguous in the last dimension
if _input.stride(-1) != 1:
@@ -267,18 +308,21 @@
X_stride=_input.stride(-2),
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
weight_ptr=weight if weight is not None else _input, # dummy if None
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
weight_sum=weight_sum,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=return_z_loss,
BLOCK_SIZE=BLOCK_SIZE,
HAS_WEIGHT=True if weight is not None else False,
HAS_SOFTCAPPING=True if softcap is not None else False,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
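Review note: because HAS_WEIGHT is a `tl.constexpr` (like the existing HAS_SOFTCAPPING and `reduction` flags), Triton specializes the kernel at compile time, so the unweighted path is compiled without the weight branches rather than paying a runtime check.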
@@ -330,6 +374,7 @@ def forward(
ctx,
_input: torch.Tensor,
target: torch.Tensor,
weight: Optional[torch.FloatTensor],
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
@@ -344,6 +389,7 @@
ctx : The context object.
_input (tensor): The input tensor of shape (BT, V) where B is batch size, T is sequence length, V is vocab size.
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
            weight (Tensor, optional): A manual rescaling weight given to each class. If given, it has to be a Tensor of size V with a floating-point dtype.
ignore_index (int): The index to ignore in the target.
lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
@@ -357,6 +403,7 @@
loss, z_loss, _input = cross_entropy_forward(
_input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
@@ -398,4 +445,5 @@ def backward(ctx, grad_output, grad_ouput2):
None,
None,
None,
None,
)
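For review, a minimal eager-mode sketch of the semantics the new argument is meant to implement (the function name and shapes here are illustrative, not from the diff); it should agree with `F.cross_entropy`'s weighted mean:

```python
import torch
import torch.nn.functional as F


def weighted_ce_reference(logits, target, weight, ignore_index=-100):
    # Per-row loss: w[y] * (logsumexp(x) - x[y]); the mean divides by the
    # summed weights of the non-ignored targets, not by their count.
    mask = target != ignore_index
    safe_target = target.clamp(min=0)  # avoid out-of-range gather on ignored rows
    lse = torch.logsumexp(logits, dim=-1)
    x_y = logits.gather(1, safe_target.unsqueeze(1)).squeeze(1)
    per_row = weight[safe_target] * (lse - x_y)
    return per_row[mask].sum() / weight[target[mask]].sum()


logits = torch.randn(8, 16)
target = torch.randint(0, 16, (8,))
target[0] = -100
weight = torch.rand(16)
assert torch.allclose(
    weighted_ce_reference(logits, target, weight),
    F.cross_entropy(logits, target, weight=weight, ignore_index=-100),
    atol=1e-6,
)
```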
4 changes: 4 additions & 0 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -83,17 +83,21 @@ def fused_linear_cross_entropy_forward(
X_stride=logits_chunk.stride(-2),
Y_ptr=target_chunk,
Y_stride=target_chunk.stride(-1), # always 1
        weight_ptr=_input,  # dummy ptr, not used when HAS_WEIGHT=False
loss_ptr=loss_1d_slice,
z_loss_ptr=loss_1d_slice, # dummy ptr, not used
loss_stride=loss_1d_slice.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
        weight_sum=n_non_ignore,  # dummy value, unused when HAS_WEIGHT=False
ignore_index=ignore_index,
lse_square_scale=lse_square_scale,
label_smoothing=label_smoothing,
reduction=reduction,
softcap=softcap if softcap is not None else 0.0,
RETURN_Z_LOSS=0, # False
HAS_WEIGHT=False,
HAS_SOFTCAPPING=True if softcap is not None else False,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=32 if not is_hip() else 16,
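Note that the fused-linear path does not take a class weight in this PR: the launch wires dummy values and sets HAS_WEIGHT=False, so its behavior is unchanged.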
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -8,6 +8,7 @@
class LigerCrossEntropyLoss(torch.nn.Module):
def __init__(
self,
weight: Optional[torch.FloatTensor] = None,
ignore_index: int = -100,
lse_square_scale: float = 0.0,
label_smoothing: float = 0.0,
@@ -30,6 +31,7 @@ def __init__(
assert (
softcap is None or softcap > 0
), f"softcap must greater than 0.0 or None. Got: {softcap}"
self.weight = weight
self.ignore_index = ignore_index
self.lse_square_scale = lse_square_scale
self.label_smoothing = label_smoothing
@@ -41,6 +43,7 @@ def forward(self, _input: torch.Tensor, target: torch.Tensor):
loss, z_loss = LigerCrossEntropyFunction.apply(
_input,
target,
self.weight,
self.ignore_index,
self.lse_square_scale,
self.label_smoothing,
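A minimal usage sketch of the new constructor argument (device and shapes are assumed; the Triton kernel needs a GPU, and the module is assumed to return the scalar loss when z-loss is not requested):

```python
import torch

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

V = 16
class_weights = torch.rand(V, device="cuda")  # one weight per class, size V
loss_fn = LigerCrossEntropyLoss(weight=class_weights)

logits = torch.randn(8, V, device="cuda", requires_grad=True)
labels = torch.randint(0, V, (8,), device="cuda")
labels[0] = -100  # ignored position
loss = loss_fn(logits, labels)
loss.backward()
```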
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/functional.py
@@ -34,6 +34,7 @@ def liger_cross_entropy(
loss, z_loss = LigerCrossEntropyFunction.apply(
input,
target,
weight,
ignore_index,
lse_square_scale,
label_smoothing,
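And a hedged end-to-end parity check against PyTorch (assumes CUDA and the default mean reduction):

```python
import torch
import torch.nn.functional as F

from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

torch.manual_seed(0)
V = 32
logits = torch.randn(64, V, device="cuda")
target = torch.randint(0, V, (64,), device="cuda")
target[::7] = -100  # sprinkle in ignored positions
weight = torch.rand(V, device="cuda")

liger_loss = LigerCrossEntropyLoss(weight=weight)(logits, target)
ref_loss = F.cross_entropy(logits, target, weight=weight, ignore_index=-100)
assert torch.allclose(liger_loss, ref_loss, atol=1e-5), (liger_loss, ref_loss)
```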