Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Z Loss in CE #239

Merged
merged 37 commits into from
Nov 7, 2024
Merged
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
0454a12
Implement z loss in LigerCrossEntropyFunction
Tcc0403 Sep 9, 2024
9349e89
Merge branch 'main' into z-loss
lancerts Sep 9, 2024
27783be
Merge branch 'main' into z-loss
lancerts Sep 9, 2024
02e90db
Rename z_loss_scale to lse_square_scale
Tcc0403 Sep 9, 2024
aa43dca
Merge branch 'z-loss' of github.com:Tcc0403/Liger-Kernel into z-loss
Tcc0403 Sep 10, 2024
aa4a4b2
Fix a mistake of the gradient calculation and update comments
Tcc0403 Sep 10, 2024
f53f61c
Remove the parameter `lse_square_scale` in FusedLinearCrossEntropyLos…
Tcc0403 Sep 10, 2024
b43c457
Implement z loss in LigerCrossEntropyFunction
Tcc0403 Sep 9, 2024
59bc0a3
Rename z_loss_scale to lse_square_scale
Tcc0403 Sep 9, 2024
0921c81
Fix a mistake of the gradient calculation and update comments
Tcc0403 Sep 10, 2024
c19f69c
Remove the parameter `lse_square_scale` in FusedLinearCrossEntropyLos…
Tcc0403 Sep 10, 2024
83c99ad
Merge branch 'z-loss' of github.com:Tcc0403/Liger-Kernel into z-loss
Tcc0403 Sep 10, 2024
1ee07de
Merge branch 'main' into ce-z-loss
Tcc0403 Sep 10, 2024
83f23d0
Support z loss in flce
Tcc0403 Sep 11, 2024
fcd5ff4
Merge branch 'main' into ce-z-loss
Tcc0403 Sep 11, 2024
295aab7
Merge branch 'main' into ce-z-loss
Tcc0403 Sep 13, 2024
f72e9bb
Fix parameter orders of ce and flce
Tcc0403 Sep 13, 2024
10fa578
Fix functional tests
Tcc0403 Sep 14, 2024
03beb05
Fix bfloat16 precision issue on custom model
Tcc0403 Sep 14, 2024
3a6cad4
Add missing arguments in test and cleanup stdout
Tcc0403 Sep 14, 2024
7e4cc4b
Merge branch 'main' into ce-z-loss
lancerts Sep 19, 2024
c0f2581
Merge branch 'main' into ce-z-loss
lancerts Sep 21, 2024
9abd163
Merge branch 'main' into ce-z-loss
Tcc0403 Sep 28, 2024
5c24241
Merge branch 'main' into ce-z-loss
Tcc0403 Oct 1, 2024
97db6b4
Merge branch 'main' into ce-z-loss
lancerts Oct 1, 2024
cf632d8
Merge branch 'main' into ce-z-loss
lancerts Oct 3, 2024
91b62fd
Merge branch 'main' into ce-z-loss
Tcc0403 Oct 12, 2024
d2d6e44
Fix merge conflicts
Tcc0403 Oct 12, 2024
f7083f2
Merge branch 'ce-z-loss' of github.com:Tcc0403/Liger-Kernel into ce-z…
Tcc0403 Oct 12, 2024
b89f335
Merge branch 'main' into ce-z-loss
Tcc0403 Oct 27, 2024
9a6079a
Merge branch 'main' into ce-z-loss
Tcc0403 Nov 2, 2024
c8d0fac
Merge branch 'main' into ce-z-loss
Tcc0403 Nov 5, 2024
c957357
chekcstyle
Tcc0403 Nov 5, 2024
4e34bf2
Merge branch 'main' into ce-z-loss
ByronHsu Nov 6, 2024
c304cc3
Merge branch 'main' into ce-z-loss
ByronHsu Nov 6, 2024
fb7aff7
Merge branch 'main' into ce-z-loss
ByronHsu Nov 7, 2024
d2ab058
Update src/liger_kernel/ops/cross_entropy.py
ByronHsu Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix merge conflicts
Tcc0403 committed Oct 12, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit d2d6e44cbebbe590b09b139eb3c9410037a1e6a5
40 changes: 2 additions & 38 deletions src/liger_kernel/ops/cross_entropy.py
Original file line number Diff line number Diff line change
@@ -2,9 +2,10 @@
import triton
import triton.language as tl

from liger_kernel.ops.utils import element_mul_kernel

_TRUE = tl.constexpr(1)
_FALSE = tl.constexpr(0)
from liger_kernel.ops.utils import element_mul_kernel


@triton.jit
@@ -186,42 +187,6 @@ def liger_cross_entropy_kernel(
MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning


@triton.jit
def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""
This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
The multiplication is performed in-place on the tensor pointed by X_ptr.

Parameters:
X_ptr: Pointer to the input tensor.
X_stride (int): The stride of the input tensor.
grad_output_ptr: Pointer to the gradient output value.
n_cols (int): The number of columns in the input tensor.
BLOCK_SIZE (int): The block size for Triton operations.
"""

# Get the program ID and convert it to int64 to avoid overflow
program_id = tl.program_id(0).to(tl.int64)

# Locate the start index
X_ptr += program_id * X_stride

# Load the gradient output value
grad_output = tl.load(grad_output_ptr)

# Perform the element-wise multiplication
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)


_bool_to_return_z_loss = {
True: _TRUE.value,
False: _FALSE.value,
@@ -247,7 +212,6 @@ def cross_entropy_forward(
return_z_loss in _bool_to_return_z_loss
), f"return_z_loss must be True or False. Got: {return_z_loss}"

def cross_entropy_forward(_input, target, ignore_index, label_smoothing, reduction):
BT, V = _input.shape
n_rows = BT