RMSNorm aggregation (#255)
## Summary
Resolve #179 
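
The backward kernel now aggregates `dW` instead of materializing a per-row buffer: the grid is one program per SM, each program accumulates a partial `dW` in fp32 over `rows_per_program` rows, and the host reduces the resulting `(sm_count, n_cols)` buffer. A minimal PyTorch sketch of this two-stage reduction (illustrative only; the helper name and the `per_row_dw` input are hypothetical, not the kernel itself):

```python
import math

import torch


def two_stage_dw_reduction(per_row_dw: torch.Tensor, sm_count: int) -> torch.Tensor:
    """Split rows across `sm_count` programs, accumulate partial sums in fp32,
    then reduce the (sm_count, n_cols) buffer, mirroring `rms_norm_backward`."""
    n_rows, n_cols = per_row_dw.shape
    rows_per_program = math.ceil(n_rows / sm_count)
    partial = torch.zeros(sm_count, n_cols, dtype=torch.float32)
    for pid in range(sm_count):  # each Triton program handles one contiguous block of rows
        start = pid * rows_per_program
        end = min(start + rows_per_program, n_rows)
        partial[pid] = per_row_dw[start:end].float().sum(dim=0)
    return partial.sum(dim=0)  # host-side reduction, as in `_dW.sum(dim=0)`
```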

## Testing Done

- Hardware Type: RTX-3080
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <tangshao28@gmail.com>
Tcc0403 and lancerts authored Oct 2, 2024
1 parent 665751e commit e1e9d2e
Showing 3 changed files with 84 additions and 48 deletions.
109 changes: 67 additions & 42 deletions src/liger_kernel/ops/rms_norm.py
@@ -10,6 +10,7 @@
Modifications made by Yanning Chen, 2024.
"""

import math
import operator

import torch
@@ -20,6 +21,7 @@
calculate_settings,
compare_version,
ensure_contiguous,
torch_to_triton_dtype,
)

if compare_version("triton", operator.ge, "3.0.0"):
@@ -84,6 +86,10 @@ def _rms_norm_forward_kernel(
W_row = W_row.to(tl.float32)
X_row = X_row.to(tl.float32)

if casting_mode == _CASTING_MODE_NONE:
eps = eps.to(X_row_dtype)
offset = offset.to(X_row_dtype)

mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
rstd = rsqrt(mean_square + eps)

@@ -100,6 +106,9 @@

Y_row = X_row * (offset + W_row)

if casting_mode == _CASTING_MODE_GEMMA:
Y_row = Y_row.to(X_row_dtype)

tl.store(Y_ptr + col_offsets, Y_row, mask=mask)


@@ -109,14 +118,17 @@ def _rms_norm_backward_kernel(
dY_row_stride,
X_ptr,
X_row_stride,
X_dtype: tl.constexpr,
W_ptr,
W_row_stride,
RSTD_ptr,
RSTD_row_stride,
dW_ptr,
dW_row_stride,
n_rows,
n_cols,
offset,
rows_per_program: tl.constexpr,
casting_mode: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
@@ -125,54 +137,60 @@ def _rms_norm_backward_kernel(
dw = sum(dy * (x / RMS)). summation over BxT dimension
"""

row_idx = tl.program_id(0)
row_block_id = tl.program_id(0)
row_start = row_block_id * rows_per_program
row_end = min((row_block_id + 1) * rows_per_program, n_rows)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols

dY_ptr += row_idx * dY_row_stride
X_ptr += row_idx * X_row_stride
RSTD_ptr += row_idx * RSTD_row_stride
dW_ptr += row_idx * dW_row_stride
dW_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
original_x_dtype = X_row.dtype

# Get cached rms
rstd_row = tl.load(RSTD_ptr)
dY_ptr += row_start * dY_row_stride
X_ptr += row_start * X_row_stride
RSTD_ptr += row_start

W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0.0)
W_row = W_row + offset

X_row = X_row.to(tl.float32)
for _ in range(row_start, row_end):
dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0.0)
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0.0)

    # Different backward graphs for different casting modes
if casting_mode == _CASTING_MODE_LLAMA:
m = (dY_row * W_row).to(tl.float32)
# Get cached rms
rstd_row = tl.load(RSTD_ptr)

elif casting_mode == _CASTING_MODE_GEMMA:
dY_row, W_row = (
dY_row.to(tl.float32),
W_row.to(tl.float32),
)
X_row = X_row.to(tl.float32)

m = dY_row * W_row
        # Different backward graphs for different casting modes
if casting_mode == _CASTING_MODE_LLAMA:
m = (dY_row * W_row).to(tl.float32)

dX_row = rstd_row * m
elif casting_mode == _CASTING_MODE_GEMMA:
dY_row = dY_row.to(tl.float32)
m = dY_row * W_row
else:
m = dY_row * W_row

dX_row += (rstd_row) * (
-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
)
dX_row = rstd_row * m

# calculate the gradient of W
if casting_mode == _CASTING_MODE_LLAMA:
dW_row = dY_row * (X_row * rstd_row).to(original_x_dtype)
else:
# here X_row is already in fp32 (see previous if block)
dW_row = dY_row * (X_row * rstd_row)
dX_row += (rstd_row) * (
-(1 / n_cols) * rstd_row * rstd_row * tl.sum(m * X_row, axis=0) * X_row
)

# calculate the gradient of W
if casting_mode == _CASTING_MODE_LLAMA:
dW_row += dY_row * (X_row * rstd_row).to(X_dtype)
else:
# here X_row is already in fp32 (see previous if block)
dW_row += dY_row * (X_row * rstd_row)

tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
tl.store(dW_ptr + col_offsets, dW_row, mask=mask)
tl.store(dY_ptr + col_offsets, dX_row.to(X_dtype), mask=mask)

dY_ptr += dY_row_stride
X_ptr += X_row_stride
RSTD_ptr += RSTD_row_stride

tl.store(dW_ptr + row_block_id * dW_row_stride + col_offsets, dW_row, mask=mask)
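
For reference, the gradients this loop computes (a sketch of the standard RMSNorm backward with m = dy * (offset + w); casting-mode details omitted):

```math
\mathrm{rstd} = \frac{1}{\sqrt{\tfrac{1}{n}\sum_j x_j^2 + \varepsilon}}, \qquad
\frac{\partial L}{\partial x_i} = \mathrm{rstd}\, m_i - \frac{\mathrm{rstd}^3}{n}\Big(\sum_j m_j x_j\Big) x_i, \qquad
\frac{\partial L}{\partial w_i} = \sum_{\text{rows}} \mathrm{d}y_i \,(x_i \cdot \mathrm{rstd})
```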


_str_to_casting_mode = {
@@ -238,31 +256,38 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
dim = shape[-1]
dY = dY.view(-1, dim)
n_rows, n_cols = dY.shape
dW = torch.empty_like(
X,
dtype=(torch.float32 if casting_mode == _CASTING_MODE_GEMMA.value else W.dtype),
)

sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
    # Accumulate dW in fp32 for numerical stability.
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

if n_cols > BLOCK_SIZE:
        raise RuntimeError("This RMSNorm doesn't support feature dim >= 64KB.")
rows_per_program = math.ceil(n_rows / sm_count)
grid = (sm_count,)
# Here we use dY to store the value of dX to save memory
_rms_norm_backward_kernel[(n_rows,)](
_rms_norm_backward_kernel[grid](
dY,
dY.stride(0),
X,
X.stride(0),
torch_to_triton_dtype[X.dtype],
W,
W.stride(0),
RSTD,
RSTD.stride(0),
dW,
dW.stride(0),
_dW,
_dW.stride(0),
n_rows,
n_cols,
offset,
rows_per_program,
casting_mode,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
dX = dY.view(*shape)
dW = torch.sum(dW, dim=0).to(W.dtype)
dW = _dW.sum(dim=0).to(W.dtype)
return dX, dW
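
Outside the test suite, a quick way to sanity-check dX and dW from these kernels is an fp32 autograd reference of the forward (a minimal sketch; the helper name is made up and casting-mode subtleties are ignored):

```python
import torch


def rms_norm_reference(x, w, offset=0.0, eps=1e-6):
    """fp32 RMSNorm forward; autograd on it yields dX/dW to compare against the kernel."""
    x32 = x.float()
    rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x32 * rstd) * (w.float() + offset)


x = torch.randn(8, 64, requires_grad=True)
w = torch.randn(64, requires_grad=True)
rms_norm_reference(x, w).sum().backward()  # x.grad / w.grad follow the formulas above
```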


8 changes: 8 additions & 0 deletions src/liger_kernel/ops/utils.py
@@ -16,6 +16,7 @@

import torch
import triton
import triton.language as tl
from packaging.version import Version


@@ -60,3 +61,10 @@ def compare_version(package: str, operator: Callable, target: str):
return False
pkg_version = Version(pkg.__version__)
return operator(pkg_version, Version(target))


torch_to_triton_dtype = {
torch.float32: tl.float32,
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
}
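
For illustration, this mapping is what lets the host hand a tensor's dtype to a kernel as a `tl.constexpr` and cast in-kernel, as `rms_norm_backward` now does with `X_dtype`. A minimal standalone sketch (kernel and tensor names are hypothetical; a CUDA device is assumed):

```python
import torch
import triton
import triton.language as tl

from liger_kernel.ops.utils import torch_to_triton_dtype


@triton.jit
def _scale_kernel(x_ptr, y_ptr, n, out_dtype: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    # compute in fp32, then store in the caller-requested dtype
    tl.store(y_ptr + offs, (x.to(tl.float32) * 2.0).to(out_dtype), mask=mask)


x = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
_scale_kernel[grid](x, y, x.numel(), torch_to_triton_dtype[x.dtype], BLOCK=256)
```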
15 changes: 9 additions & 6 deletions test/transformers/test_rms_norm.py
@@ -1,5 +1,5 @@
import os
from test.utils import assert_verbose_allclose, supports_bfloat16
from test.utils import assert_verbose_allclose, set_seed, supports_bfloat16

import pytest
import torch
@@ -9,6 +9,7 @@
from liger_kernel.transformers.functional import liger_rms_norm
from liger_kernel.transformers.rms_norm import LigerRMSNorm

set_seed(42)
torch.use_deterministic_algorithms(True)

# Only setting torch.use_deterministic_algorithms(True) might throw the following error:
@@ -75,8 +76,8 @@ def forward(self, x):
(2, 128, 512),
(4, 256, 1024),
(8, 512, 2048),
(16, 1024, 4096),
# # weird shapes
(8, 1024, 4096),
# # # weird shapes
(3, 423, 213),
(5, 123, 123),
(7, 341, 234),
@@ -121,7 +122,7 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_m
# reference (llama or gemma)
ref_rms = reference(hidden_size=hd).to("cuda").to(dtype)
ref_o = ref_rms(h1)
ref_o.backward(do.clone(), retain_graph=True)
ref_o.backward(do, retain_graph=True)

# triton
triton_rms = (
@@ -130,13 +131,15 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_m
.to(dtype)
)
triton_o = triton_rms(h2)
triton_o.backward(do.clone(), retain_graph=True)
triton_o.backward(do, retain_graph=True)

assert_verbose_allclose(ref_o, triton_o, atol=atol, rtol=rtol)
assert_verbose_allclose(
ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol
)
assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol)
print(f"{h1.grad=}")
print(f"{h2.grad=}")
assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol, max_print=20)


@pytest.mark.parametrize(

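For context, a minimal usage sketch of the module these tests exercise (a CUDA device is assumed; constructor keywords beyond `hidden_size`, such as `offset` and `casting_mode`, are left at their defaults since their exact signature is not shown in this diff):

```python
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm

rms = LigerRMSNorm(hidden_size=1024).to("cuda").to(torch.bfloat16)
x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)
rms(x).sum().backward()  # dX and dW go through the aggregated backward kernel above
```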