Skip to content

Commit

Permalink
faster layernorm
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux committed Apr 19, 2022
1 parent 0bdd074 commit 00b9763
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Mem efficient attention, FW pass [#267]
- MHA benchmark
- MLP benchmark
- Move all triton kernels to triton v2 [#272]

## [0.0.10] - 2022-03-14
### Fixed
Expand Down
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/layer_norm/LayerNorm_FW+BW_torch.float32.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/plots/layer_norm/LayerNorm_FW_torch.float32.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion xformers/triton/k_fused_matmul_bw.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=4),
triton.Config({"BLOCK_N": 256}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_N": 512}, num_stages=3, num_warps=8),
triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=16),
triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=8),
],
key=["N"],
)
Expand Down
24 changes: 10 additions & 14 deletions xformers/triton/k_layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,6 @@
import triton.language as tl


@triton.jit
def _store(y, Y, stride, N, BLOCK_SIZE_N: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)

y_ptrs = Y + row * stride + cols
tl.store(y_ptrs, y, mask=cols < N)


# fmt: off
@triton.jit
def layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, affine: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
Expand Down Expand Up @@ -87,16 +78,16 @@ def layer_norm_bwd_dx_fused(
dy_ptrs = DY + row * stride + cols

# load data to SRAM
x = tl.load(x_ptrs, mask=mask, other=0).to(tl.float32)
dy = tl.load(dy_ptrs, mask=mask, other=0).to(tl.float32)
x = tl.load(x_ptrs, mask=mask, other=0)
dy = tl.load(dy_ptrs, mask=mask, other=0)
mean = tl.load(M + row)
rstd = tl.load(V + row)

# compute dx
xhat = (x - mean) * rstd

if affine:
w = tl.load(W + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask, other=0)
wdy = w * dy
else:
wdy = dy
Expand All @@ -108,6 +99,7 @@ def layer_norm_bwd_dx_fused(
dx = (wdy - (xhat * mean1 + mean2)) * rstd

# write-back dx
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N # re-materialize the mask to save registers
dx_ptrs = DX + row * stride + cols
tl.store(dx_ptrs, dx, mask=mask)
Expand Down Expand Up @@ -172,12 +164,16 @@ def layer_norm_bwd_dwdb(
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
offs = rows[:, None] * N + cols[None, :]
mask_rm = rows < M

dw += tl.load(DW + offs, mask=(rows[:, None] < M) & mask_cols[None, :], other=0.0)
db += tl.load(DB + offs, mask=(rows[:, None] < M) & mask_cols[None, :], other=0.0)
dw += tl.load(DW + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)
db += tl.load(DB + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0)

sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)

cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask_cols = cols < N

tl.store(FINAL_DW + cols, sum_dw, mask=mask_cols)
tl.store(FINAL_DB + cols, sum_db, mask=mask_cols)
11 changes: 8 additions & 3 deletions xformers/triton/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,19 @@ def backward(
M, N = x.size()

# heuristics for amount of parallel reduction stream for DG/DB
GROUP_SIZE_M = 64
GROUP_SIZE_M = 32
if N <= 8192:
GROUP_SIZE_M = 96
GROUP_SIZE_M = 64
if N <= 4096:
GROUP_SIZE_M = 96
if N <= 2048:
GROUP_SIZE_M = 128
if N <= 1024:
GROUP_SIZE_M = 256

if dy.dtype == torch.float32:
GROUP_SIZE_M = GROUP_SIZE_M // 2

# allocate output
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device="cuda")
t_args = {"dtype": x.dtype, "device": x.device}
Expand Down Expand Up @@ -150,7 +155,7 @@ def grid(meta):
GROUP_SIZE_M,
N,
BLOCK_SIZE_M=32,
BLOCK_SIZE_N=128
BLOCK_SIZE_N=64
)
# fmt: on

Expand Down

0 comments on commit 00b9763

Please sign in to comment.