From da2fd5d4cc272d031d7369abd79cb37a6b5cfbc7 Mon Sep 17 00:00:00 2001 From: pramodith Date: Thu, 31 Oct 2024 14:02:49 +0000 Subject: [PATCH 01/18] Initial commit --- src/liger_kernel/ops/group_norm.py | 241 ++++++++++++++++++++ src/liger_kernel/transformers/functional.py | 2 + src/liger_kernel/transformers/group_norm.py | 47 ++++ test/transformers/test_group_norm.py | 101 ++++++++ 4 files changed, 391 insertions(+) create mode 100644 src/liger_kernel/ops/group_norm.py create mode 100644 src/liger_kernel/transformers/group_norm.py create mode 100644 test/transformers/test_group_norm.py diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py new file mode 100644 index 000000000..3ff40b06d --- /dev/null +++ b/src/liger_kernel/ops/group_norm.py @@ -0,0 +1,241 @@ +import math +import operator + +import torch +import triton +import triton.language as tl + +from liger_kernel.ops.utils import ( + calculate_settings, + compare_version, + ensure_contiguous, +) + +if compare_version("triton", operator.ge, "3.0.0"): + try: + # typical import path with dispatch available + from triton.language.extra.libdevice import rsqrt + except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import rsqrt +else: + from triton.language.math import rsqrt + + +@triton.jit +def _group_norm_forward_kernel( + Y_ptr, # pointer to output, shape (n_rows, n_groups, hidden_size) + Y_row_stride, # stride of each row in output + Y_col_stride, # stride of each column in output + X_ptr, # pointer to input, shape (n_rows, n_groups, hidden_size) + X_row_stride, # stride of each row in input + X_col_stride, # stride of each column in input + Mean_ptr, # pointer to mean, shape (n_rows, n_groups) + Mean_row_stride, # stride of each row in mean + Mean_col_stride, # stride of each column in mean + RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups) + RSTD_row_stride, # stride of each row in rstd + RSTD_col_stride, # stride of each column in rstd + W_ptr, # pointer to weights, shape (n_groups) + B_ptr, # pointer to bias, shape (n_groups) + hidden_size, + num_channels, + num_rows, + eps, + BLOCK_SIZE: tl.constexpr, +): + """ + References: + https://nn.labml.ai/normalization/group_norm/index.html + """ + row_idx = tl.program_id(0) + col_idx = tl.program_id(1) + + hidden_size_offsets = tl.arange(0, BLOCK_SIZE) + channel_offsets = tl.arange(0, num_channels) + hidden_size_mask = hidden_size_offsets < hidden_size + + Y_ptr += row_idx * Y_row_stride + col_idx * Y_col_stride + X_ptr += row_idx * X_row_stride + col_idx * X_col_stride + Mean_ptr += row_idx * Mean_row_stride + col_idx * Mean_col_stride + RSTD_ptr += row_idx * RSTD_row_stride + col_idx * RSTD_col_stride + + X_row = tl.load(X_ptr + hidden_size_offsets, mask=hidden_size_mask, other=0) + W = tl.load(W_ptr + channel_offsets) + B = tl.load(B_ptr + channel_offsets) + + mean = tl.sum(X_row, axis=-1) / hidden_size + diff = X_row - mean + var = tl.sum(diff * diff, axis=-1) / hidden_size + rstd = rsqrt(var + eps) + + tl.store(Mean_ptr, mean) + tl.store(RSTD_ptr, rstd) + X_row_reshaped = tl.view(num_rows, num_channels, hidden_size) + Y_row = (X_row_reshaped - mean) * rstd * W + B + + tl.store(Y_ptr + hidden_size_offsets, Y_row, mask=hidden_size_mask, other=0) + + +@triton.jit +def _group_norm_backward_kernel( + X_ptr, # pointer to input, shape (n_rows, n_cols) + W_ptr, # pointer to weights, shape (n_cols,) + Mean_ptr, # pointer to mean, shape (n_rows,) + RSTD_ptr, # pointer to rstd, shape (n_rows,) + DX_ptr, # pointer to input grad, shape (n_rows, n_cols) + DW_ptr, # pointer to weights grad, shape (n_cols,) + DB_ptr, # pointer to bias grad, shape (n_cols,) + DY_ptr, # pointer to output grad, shape (n_rows, n_cols) + stride_x, # stride of each row in input + stride_dx, # stride of each row in input grad + stride_dw, # stride of each row in weights grad + stride_db, # stride of each row in bias grad + stride_dy, # stride of each row in output grad + n_rows, + n_cols, + rows_per_program: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + dtype: tl.constexpr, +): + """ + References: + https://nn.labml.ai/normalization/group_norm/index.html + """ + row_block_id = tl.program_id(0) + row_start = row_block_id * rows_per_program + row_end = min((row_block_id + 1) * rows_per_program, n_rows) + cols = tl.arange(0, BLOCK_SIZE) + mask = cols < n_cols + + dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + X_ptr += row_start * stride_x + Mean_ptr += row_start + RSTD_ptr += row_start + DX_ptr += row_start * stride_dx + DY_ptr += row_start * stride_dy + + for _ in range(row_start, row_end): + x = tl.load(X_ptr + cols, mask=mask, other=0.0) + w = tl.load(W_ptr + cols, mask=mask, other=0.0) + dy = tl.load(DY_ptr + cols, mask=mask, other=0.0) + mean = tl.load(Mean_ptr) + rstd = tl.load(RSTD_ptr) + + x_hat = (x - mean) * rstd + wdy = w * dy + c1 = tl.sum(x_hat * wdy, axis=0) / n_cols + c2 = tl.sum(wdy, axis=0) / n_cols + dx = (wdy - (x_hat * c1 + c2)) * rstd + tl.store(DX_ptr + cols, dx.to(dtype), mask=mask) + + dw_row += dy * x_hat + db_row += dy + + X_ptr += stride_x + Mean_ptr += 1 + RSTD_ptr += 1 + DX_ptr += stride_dx + DY_ptr += stride_dy + + tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask) + tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask) + + +def group_norm_forward(X, num_channels, num_groups, W, B, eps): + shape = X.shape + batch_size = shape[0] + # Reshape X so that the mean and std are computed across the groups + X = X.view(batch_size, num_groups, -1) + hidden_size = X.shape[-1] + BLOCK_SIZE, num_warps = calculate_settings(hidden_size) + Y = torch.empty((batch_size, num_channels, hidden_size), dtype=X.dtype, device=X.device) + Mean = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) + RSTD = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) + + _group_norm_forward_kernel[(batch_size, num_groups)]( + Y, + Y.stride(0), + Y.stride(1), + X, + X.stride(0), + X.stride(1), + Mean, + Mean.stride(0), + Mean.stride(1), + RSTD, + RSTD.stride(0), + RSTD.stride(1), + W, + B, + hidden_size, + batch_size, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return Y, X, Mean, RSTD, BLOCK_SIZE, num_warps + + +def group_norm_backward(dY, X, W, B, Mean, RSTD): + shape = dY.shape + dim = shape[-1] + dY = dY.view(-1, dim) + n_rows, n_cols = dY.shape + + DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) + sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count + _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) + _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) + + BLOCK_SIZE, num_warps = calculate_settings(n_cols) + if n_cols > BLOCK_SIZE: + raise RuntimeError("This group norm doesn't support feature dim >= 64KB.") + + rows_per_program = math.ceil(n_rows / sm_count) + grid = (sm_count,) + triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16 + _group_norm_backward_kernel[grid]( + X, + W, + Mean, + RSTD, + DX, + _DW, + _DB, + dY, + X.stride(0), + DX.stride(0), + _DW.stride(0), + _DB.stride(0), + dY.stride(0), + n_rows, + n_cols, + rows_per_program, + BLOCK_SIZE=BLOCK_SIZE, + dtype=triton_dtype, + ) + + DW = _DW.sum(dim=0).to(W.dtype) + DB = _DB.sum(dim=0).to(W.dtype) + + DX = DX.view(*shape) + return DX, DW, DB + + +class LigerGroupNormFunction(torch.autograd.Function): + @staticmethod + @ensure_contiguous + def forward(ctx, X, num_channels, num_groups, affine_scaling_weight, affine_shifting_bias, eps): + Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = group_norm_forward(X, num_channels, num_groups, affine_scaling_weight, affine_shifting_bias, eps) + ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD) + return Y + + @staticmethod + @ensure_contiguous + def backward(ctx, dY): + X, W, B, Mean, RSTD = ctx.saved_tensors + DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD) + return DX, DW, DB, None diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index f160887b8..39672000f 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -10,6 +10,7 @@ from liger_kernel.ops.rms_norm import LigerRMSNormFunction from liger_kernel.ops.rope import LigerRopeFunction from liger_kernel.ops.swiglu import LigerSiLUMulFunction +from liger_kernel.ops.group_norm import LigerGroupNormFunction liger_swiglu = LigerSiLUMulFunction.apply liger_cross_entropy = LigerCrossEntropyFunction.apply @@ -21,3 +22,4 @@ liger_kl_div = LigerKLDivLossFunction.apply liger_jsd = LigerJSDFunction.apply liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply +liger_group_norm = LigerGroupNormFunction.apply \ No newline at end of file diff --git a/src/liger_kernel/transformers/group_norm.py b/src/liger_kernel/transformers/group_norm.py new file mode 100644 index 000000000..102b13668 --- /dev/null +++ b/src/liger_kernel/transformers/group_norm.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn + +from liger_kernel.ops.group_norm import LigerGroupNormFunction + + +class LigerGroupNorm(nn.Module): + def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones"): + """ + A Group Normalization layer. + Args: + num_channels (int): Number of channels in the input tensor. + num_groups (int): Number of groups to divide the channels into. + eps (float, optional): A value added to the denominator for numerical stability. Default: 1e-6. + bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``. + init_fn (str, optional): Initialization function for the learnable parameters. Default: "ones". + """ + super().__init__() + assert init_fn in [ + "ones", + "zeros", + ], f"init_fn must be either 'ones' or 'zeros', got {init_fn}" + + assert ( + num_channels % num_groups == 0 + ), f"Number of channels {num_channels} must be divisible by num_groups {num_groups}" + self.num_channels = num_channels + self.num_groups = num_groups + self.eps = eps + self.affine_scaling_weight = nn.Parameter( + torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels) + ) + self.affine_shifting_bias = nn.Parameter( + torch.randn(num_channels) if bias else torch.zeros(num_channels) + ) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # hidden_states: (batch_size, num_channels, *) + assert hidden_states.dim() >= 3, f"Input must have atleast 3 dimensions, got {hidden_states.dim()}" + assert hidden_states.size(1) == self.num_channels, f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}" + return LigerGroupNormFunction.apply( + hidden_states, self.num_channels, self.num_groups, self.affine_scaling_weight, self.affine_shifting_bias, self.variance_epsilon + ) + + def extra_repr(self): + return f"{self.hidden_size}, num_channels={self.num_channels}, num_groups={self.num_groups}, eps={self.eps}" diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py new file mode 100644 index 000000000..516faa0f2 --- /dev/null +++ b/test/transformers/test_group_norm.py @@ -0,0 +1,101 @@ +import pytest +import torch + +from liger_kernel.ops.group_norm import LigerGroupNormFunction +from liger_kernel.transformers.functional import liger_group_norm +from liger_kernel.transformers.group_norm import LigerGroupNorm + + +@pytest.mark.parametrize( + "batch_size, num_channels, num_groups, hidden_size", + [ + (1, 3, 1, 128), + (2, 4, 2, 128), + (16, 12, 3, 128), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float32, 1e-5, 1e-5), + ], +) +def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol): + torch.manual_seed(0) + + x = torch.randn( + batch_size, num_channels, hidden_size, dtype=dtype, device="cuda", requires_grad=True + ) + liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).cuda() + torch_ln = torch.nn.GroupNorm(num_channels, num_groups, hidden_size, eps=1e-6).to(dtype).cuda() + + with torch.no_grad(): + torch_ln.weight.copy_(liger_ln.weight) + torch_ln.bias.copy_(liger_ln.bias) + + liger_output = liger_ln(x) + torch_output = torch_ln(x) + + assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol) + + # grad_output = torch.randn_like(x) + # liger_output.backward(grad_output, retain_graph=True) + # torch_output.backward(grad_output, retain_graph=True) + + # assert torch.allclose(x.grad, x.grad, atol=atol, rtol=rtol) + # assert torch.allclose( + # liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol + # ) + # assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol) + + +# @pytest.mark.parametrize( +# "hidden_size", +# [8, 41], +# ) +# @pytest.mark.parametrize( +# "batch_size, seq_len", +# [ +# (2, 2), +# (9, 7), +# ], +# ) +# @pytest.mark.parametrize( +# "dtype, atol, rtol", +# [ +# (torch.float32, 1e-5, 1e-5), +# ], +# ) +# def test_liger_group_norm_functional( +# hidden_size, batch_size, seq_len, dtype, atol, rtol +# ): +# torch.manual_seed(0) + +# input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + +# x1 = input.clone().requires_grad_(True) +# x2 = input.clone().requires_grad_(True) + +# w = torch.randn(hidden_size, device="cuda", dtype=dtype) + +# w1 = w.clone().requires_grad_(True) +# w2 = w.clone().requires_grad_(True) + +# b = torch.randn(hidden_size, device="cuda", dtype=dtype) + +# b1 = b.clone().requires_grad_(True) +# b2 = b.clone().requires_grad_(True) + +# y1 = liger_group_norm(x1, w1, b1, 1e-6) +# y2 = LigergroupNormFunction.apply(x2, w2, b2, 1e-6) + +# assert torch.allclose(y1, y2, atol=atol, rtol=rtol) + +# grad_output = torch.randn_like(y2) + +# y1.backward(grad_output, retain_graph=True) +# y2.backward(grad_output, retain_graph=True) + +# assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) +# assert torch.allclose(w1.grad, w2.grad, atol=atol, rtol=rtol) +# assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol) From 50880c9e3fd6ddb4a3334c3f6470accfec1ed046 Mon Sep 17 00:00:00 2001 From: pramodith Date: Thu, 31 Oct 2024 17:46:30 +0000 Subject: [PATCH 02/18] Forward pass works --- src/liger_kernel/ops/group_norm.py | 71 +++++++++++++++++----------- test/transformers/test_group_norm.py | 12 ++--- 2 files changed, 50 insertions(+), 33 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 3ff40b06d..5d7b53cb2 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -42,41 +42,51 @@ def _group_norm_forward_kernel( num_channels, num_rows, eps, - BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr ): """ References: https://nn.labml.ai/normalization/group_norm/index.html """ - row_idx = tl.program_id(0) - col_idx = tl.program_id(1) - - hidden_size_offsets = tl.arange(0, BLOCK_SIZE) - channel_offsets = tl.arange(0, num_channels) - hidden_size_mask = hidden_size_offsets < hidden_size - - Y_ptr += row_idx * Y_row_stride + col_idx * Y_col_stride - X_ptr += row_idx * X_row_stride + col_idx * X_col_stride - Mean_ptr += row_idx * Mean_row_stride + col_idx * Mean_col_stride - RSTD_ptr += row_idx * RSTD_row_stride + col_idx * RSTD_col_stride - - X_row = tl.load(X_ptr + hidden_size_offsets, mask=hidden_size_mask, other=0) - W = tl.load(W_ptr + channel_offsets) - B = tl.load(B_ptr + channel_offsets) - - mean = tl.sum(X_row, axis=-1) / hidden_size - diff = X_row - mean - var = tl.sum(diff * diff, axis=-1) / hidden_size - rstd = rsqrt(var + eps) + batch_idx = tl.program_id(0) + group_idx = tl.program_id(1) - tl.store(Mean_ptr, mean) - tl.store(RSTD_ptr, rstd) - X_row_reshaped = tl.view(num_rows, num_channels, hidden_size) - Y_row = (X_row_reshaped - mean) * rstd * W + B + X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride + Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride + + # Compute mean + sum = 0.0 + for i in range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + mask = hidden_size_offsets < hidden_size + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) + sum += tl.sum(X) + + mean = sum / hidden_size + tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, mean) + + # Compute variance + variance = 0.0 + for i in range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + mask = hidden_size_offsets < hidden_size + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) + diff = X - mean + variance += tl.sum(diff * diff) + + variance = variance / (hidden_size) + std = tl.sqrt(variance + eps) - tl.store(Y_ptr + hidden_size_offsets, Y_row, mask=hidden_size_mask, other=0) + tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, variance) + for i in range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + mask = hidden_size_offsets < hidden_size + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) + Y = (X - mean) / std + tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask) + @triton.jit def _group_norm_backward_kernel( X_ptr, # pointer to input, shape (n_rows, n_cols) @@ -151,7 +161,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): X = X.view(batch_size, num_groups, -1) hidden_size = X.shape[-1] BLOCK_SIZE, num_warps = calculate_settings(hidden_size) - Y = torch.empty((batch_size, num_channels, hidden_size), dtype=X.dtype, device=X.device) + Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) Mean = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) RSTD = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) @@ -171,11 +181,18 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): W, B, hidden_size, + num_channels, batch_size, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, ) + + Y = Y.view(*shape) + affine_shape = [1] * len(shape) + affine_shape[1] = num_channels + Y = Y * W.view(affine_shape) + B.view(affine_shape) + return Y, X, Mean, RSTD, BLOCK_SIZE, num_warps diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 516faa0f2..73ba998b3 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -9,9 +9,9 @@ @pytest.mark.parametrize( "batch_size, num_channels, num_groups, hidden_size", [ - (1, 3, 1, 128), + (1, 4, 2, 8), (2, 4, 2, 128), - (16, 12, 3, 128), + (16, 12, 3, 4096), ], ) @pytest.mark.parametrize( @@ -26,16 +26,16 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty x = torch.randn( batch_size, num_channels, hidden_size, dtype=dtype, device="cuda", requires_grad=True ) - liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).cuda() - torch_ln = torch.nn.GroupNorm(num_channels, num_groups, hidden_size, eps=1e-6).to(dtype).cuda() + liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).cuda() + torch_ln = torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6).to(dtype).cuda() + with torch.no_grad(): torch_ln.weight.copy_(liger_ln.weight) torch_ln.bias.copy_(liger_ln.bias) - liger_output = liger_ln(x) + liger_output = liger_ln(x,) torch_output = torch_ln(x) - assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol) # grad_output = torch.randn_like(x) From 24e3201ee9cddd66f6f422e818588fe4d1bbaf9f Mon Sep 17 00:00:00 2001 From: pramodith Date: Fri, 1 Nov 2024 16:18:11 +0000 Subject: [PATCH 03/18] Backward works partially --- src/liger_kernel/ops/group_norm.py | 202 ++++++++++---------- src/liger_kernel/transformers/group_norm.py | 6 +- test/transformers/test_group_norm.py | 65 +------ 3 files changed, 107 insertions(+), 166 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 5d7b53cb2..8e4ea332d 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -36,11 +36,7 @@ def _group_norm_forward_kernel( RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups) RSTD_row_stride, # stride of each row in rstd RSTD_col_stride, # stride of each column in rstd - W_ptr, # pointer to weights, shape (n_groups) - B_ptr, # pointer to bias, shape (n_groups) hidden_size, - num_channels, - num_rows, eps, BLOCK_SIZE: tl.constexpr ): @@ -75,84 +71,97 @@ def _group_norm_forward_kernel( variance += tl.sum(diff * diff) variance = variance / (hidden_size) - std = tl.sqrt(variance + eps) - - tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, variance) + # 1/std + rstd = rsqrt(variance + eps) + tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd) + # Normalize for i in range(0, hidden_size, BLOCK_SIZE): hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) mask = hidden_size_offsets < hidden_size X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) - Y = (X - mean) / std + Y = (X - mean) * rstd tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask) @triton.jit def _group_norm_backward_kernel( - X_ptr, # pointer to input, shape (n_rows, n_cols) - W_ptr, # pointer to weights, shape (n_cols,) - Mean_ptr, # pointer to mean, shape (n_rows,) - RSTD_ptr, # pointer to rstd, shape (n_rows,) - DX_ptr, # pointer to input grad, shape (n_rows, n_cols) - DW_ptr, # pointer to weights grad, shape (n_cols,) - DB_ptr, # pointer to bias grad, shape (n_cols,) - DY_ptr, # pointer to output grad, shape (n_rows, n_cols) - stride_x, # stride of each row in input - stride_dx, # stride of each row in input grad - stride_dw, # stride of each row in weights grad - stride_db, # stride of each row in bias grad - stride_dy, # stride of each row in output grad - n_rows, - n_cols, - rows_per_program: tl.constexpr, + X_ptr, # pointer to input, shape (n_rows, n_channels, hidden_size) + X_row_stride, # stride of each row in input + X_col_stride, # stride of each column in input + W_ptr, # pointer to weights, shape (n_channels) + Mean_ptr, # pointer to mean, shape (n_rows, n_groups) + Mean_ptr_col_stride, # stride of each column in mean + RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups) + DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size) + DW_ptr, # pointer to weights grad, shape (n_channels) + DW_col_stride, # stride of each column in weights + DB_ptr, # pointer to bias grad, shape (n_channels) + UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size) + hidden_size: tl.constexpr, # hidden size + num_groups: tl.constexpr, # number of groups in group norm BLOCK_SIZE: tl.constexpr, - dtype: tl.constexpr, ): """ References: https://nn.labml.ai/normalization/group_norm/index.html """ - row_block_id = tl.program_id(0) - row_start = row_block_id * rows_per_program - row_end = min((row_block_id + 1) * rows_per_program, n_rows) - cols = tl.arange(0, BLOCK_SIZE) - mask = cols < n_cols - - dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) - - X_ptr += row_start * stride_x - Mean_ptr += row_start - RSTD_ptr += row_start - DX_ptr += row_start * stride_dx - DY_ptr += row_start * stride_dy - - for _ in range(row_start, row_end): - x = tl.load(X_ptr + cols, mask=mask, other=0.0) - w = tl.load(W_ptr + cols, mask=mask, other=0.0) - dy = tl.load(DY_ptr + cols, mask=mask, other=0.0) - mean = tl.load(Mean_ptr) - rstd = tl.load(RSTD_ptr) + batch_idx = tl.program_id(0) + channel_idx = tl.program_id(1) - x_hat = (x - mean) * rstd - wdy = w * dy - c1 = tl.sum(x_hat * wdy, axis=0) / n_cols - c2 = tl.sum(wdy, axis=0) / n_cols - dx = (wdy - (x_hat * c1 + c2)) * rstd - tl.store(DX_ptr + cols, dx.to(dtype), mask=mask) + group_idx = channel_idx // num_groups - dw_row += dy * x_hat - db_row += dy + # X_col_stide will correspond to the number of groups + X_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride + DX_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride + UPSTREAM_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride - X_ptr += stride_x - Mean_ptr += 1 - RSTD_ptr += 1 - DX_ptr += stride_dx - DY_ptr += stride_dy + DW_ptr += batch_idx * X_row_stride + channel_idx * DW_col_stride + DB_ptr += batch_idx * X_row_stride + channel_idx * DW_col_stride + # Mean and rstd are the same shape so have the same strides + mean = tl.load(Mean_ptr + batch_idx * X_row_stride + group_idx * Mean_ptr_col_stride) + rstd = tl.load(RSTD_ptr + batch_idx * X_row_stride + group_idx * Mean_ptr_col_stride) + W = tl.load(W_ptr + group_idx) + + DW = 0.0 + DB = 0.0 - tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask) - tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask) + for i in range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + mask = hidden_size_offsets < hidden_size + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) + UPSTREAM_grad = tl.load(UPSTREAM_ptr + hidden_size_offsets, mask=mask, other=0.0) + """ + Y = (X - mean) * rstd + + h(x) = rstd = 1/(sqrt(var + eps)) + f(x) = x * h(x) = X * rstd + g(x) = - mean * h(x) = - mean * rstd + + Y = f(x) + g(x) + dy_dx = df_dx + dg_dx + """ + + # dh_dx = -0.5 * (rstd**3) * dvar_dx + c1 = 1 / hidden_size + c2 = X - mean + dmean_dx = c1 + dvar_dx = 2 * c2 * c1 + drstd_dx = -0.5 * (rstd*rstd*rstd) * dvar_dx + + df_dx = rstd + X * drstd_dx + dg_dx = - (dmean_dx * rstd + mean * drstd_dx) + dY_dx = df_dx + dg_dx + + DX = W * UPSTREAM_grad * dY_dx + + DW += tl.sum(UPSTREAM_grad * X) + DB += tl.sum(UPSTREAM_grad) + tl.store(DX_ptr + hidden_size_offsets, DX, mask=mask) + + tl.store(DW_ptr, DW) + tl.store(DB_ptr, DB) def group_norm_forward(X, num_channels, num_groups, W, B, eps): shape = X.shape @@ -165,7 +174,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): Mean = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) RSTD = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) - _group_norm_forward_kernel[(batch_size, num_groups)]( + _group_norm_forward_kernel[(batch_size, num_channels)]( Y, Y.stride(0), Y.stride(1), @@ -178,11 +187,7 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): RSTD, RSTD.stride(0), RSTD.stride(1), - W, - B, hidden_size, - num_channels, - batch_size, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, @@ -196,57 +201,44 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): return Y, X, Mean, RSTD, BLOCK_SIZE, num_warps -def group_norm_backward(dY, X, W, B, Mean, RSTD): +def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): shape = dY.shape - dim = shape[-1] - dY = dY.view(-1, dim) - n_rows, n_cols = dY.shape - - DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) - sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count - _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) - _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device) - - BLOCK_SIZE, num_warps = calculate_settings(n_cols) - if n_cols > BLOCK_SIZE: - raise RuntimeError("This group norm doesn't support feature dim >= 64KB.") - - rows_per_program = math.ceil(n_rows / sm_count) - grid = (sm_count,) - triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16 - _group_norm_backward_kernel[grid]( + print(dY) + batch_size = shape[0] + hidden_size = dY.shape[-1] + DX = torch.empty((batch_size, num_channels, hidden_size), dtype=X.dtype, device=X.device) + DW = torch.empty((batch_size, num_channels), dtype=W.dtype, device=W.device) + DB = torch.empty((batch_size, num_channels), dtype=B.dtype, device=B.device) + BLOCK_SIZE, num_warps = calculate_settings(hidden_size) + _group_norm_backward_kernel[(batch_size, num_channels)]( X, + X.stride(0), + X.stride(1), W, Mean, + Mean.stride(1), RSTD, DX, - _DW, - _DB, + DW, + DW.stride(1), + DB, dY, - X.stride(0), - DX.stride(0), - _DW.stride(0), - _DB.stride(0), - dY.stride(0), - n_rows, - n_cols, - rows_per_program, - BLOCK_SIZE=BLOCK_SIZE, - dtype=triton_dtype, + hidden_size, + num_groups, + BLOCK_SIZE=BLOCK_SIZE ) - - DW = _DW.sum(dim=0).to(W.dtype) - DB = _DB.sum(dim=0).to(W.dtype) - - DX = DX.view(*shape) - return DX, DW, DB + print(DB) + print(DW) + return DX, DW.sum(dim=0), DB.sum(dim=0) class LigerGroupNormFunction(torch.autograd.Function): @staticmethod @ensure_contiguous - def forward(ctx, X, num_channels, num_groups, affine_scaling_weight, affine_shifting_bias, eps): + def forward(ctx, X, affine_scaling_weight, affine_shifting_bias, num_channels, num_groups, eps): Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = group_norm_forward(X, num_channels, num_groups, affine_scaling_weight, affine_shifting_bias, eps) + ctx.num_channels = num_channels + ctx.num_groups = num_groups ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD) return Y @@ -254,5 +246,5 @@ def forward(ctx, X, num_channels, num_groups, affine_scaling_weight, affine_shif @ensure_contiguous def backward(ctx, dY): X, W, B, Mean, RSTD = ctx.saved_tensors - DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD) - return DX, DW, DB, None + DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups) + return DX, DW, DB, None, None, None diff --git a/src/liger_kernel/transformers/group_norm.py b/src/liger_kernel/transformers/group_norm.py index 102b13668..2b2631507 100644 --- a/src/liger_kernel/transformers/group_norm.py +++ b/src/liger_kernel/transformers/group_norm.py @@ -27,10 +27,10 @@ def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones self.num_channels = num_channels self.num_groups = num_groups self.eps = eps - self.affine_scaling_weight = nn.Parameter( + self.weight = nn.Parameter( torch.ones(num_channels) if init_fn == "ones" else torch.zeros(num_channels) ) - self.affine_shifting_bias = nn.Parameter( + self.bias = nn.Parameter( torch.randn(num_channels) if bias else torch.zeros(num_channels) ) self.variance_epsilon = eps @@ -40,7 +40,7 @@ def forward(self, hidden_states): assert hidden_states.dim() >= 3, f"Input must have atleast 3 dimensions, got {hidden_states.dim()}" assert hidden_states.size(1) == self.num_channels, f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}" return LigerGroupNormFunction.apply( - hidden_states, self.num_channels, self.num_groups, self.affine_scaling_weight, self.affine_shifting_bias, self.variance_epsilon + hidden_states, self.weight, self.bias, self.num_channels, self.num_groups, self.variance_epsilon ) def extra_repr(self): diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 73ba998b3..14b195b11 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize( "batch_size, num_channels, num_groups, hidden_size", [ - (1, 4, 2, 8), + (1, 2, 1, 4), (2, 4, 2, 128), (16, 12, 3, 4096), ], @@ -18,6 +18,7 @@ "dtype, atol, rtol", [ (torch.float32, 1e-5, 1e-5), + (torch.float16, 1e-3, 1e-3), ], ) def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol): @@ -38,64 +39,12 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty torch_output = torch_ln(x) assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol) - # grad_output = torch.randn_like(x) - # liger_output.backward(grad_output, retain_graph=True) - # torch_output.backward(grad_output, retain_graph=True) + grad_output = torch.randn_like(x) + liger_output.backward(grad_output, retain_graph=True) + torch_output.backward(grad_output, retain_graph=True) - # assert torch.allclose(x.grad, x.grad, atol=atol, rtol=rtol) + assert torch.allclose(x.grad, x.grad, atol=atol, rtol=rtol) # assert torch.allclose( # liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol # ) - # assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol) - - -# @pytest.mark.parametrize( -# "hidden_size", -# [8, 41], -# ) -# @pytest.mark.parametrize( -# "batch_size, seq_len", -# [ -# (2, 2), -# (9, 7), -# ], -# ) -# @pytest.mark.parametrize( -# "dtype, atol, rtol", -# [ -# (torch.float32, 1e-5, 1e-5), -# ], -# ) -# def test_liger_group_norm_functional( -# hidden_size, batch_size, seq_len, dtype, atol, rtol -# ): -# torch.manual_seed(0) - -# input = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") - -# x1 = input.clone().requires_grad_(True) -# x2 = input.clone().requires_grad_(True) - -# w = torch.randn(hidden_size, device="cuda", dtype=dtype) - -# w1 = w.clone().requires_grad_(True) -# w2 = w.clone().requires_grad_(True) - -# b = torch.randn(hidden_size, device="cuda", dtype=dtype) - -# b1 = b.clone().requires_grad_(True) -# b2 = b.clone().requires_grad_(True) - -# y1 = liger_group_norm(x1, w1, b1, 1e-6) -# y2 = LigergroupNormFunction.apply(x2, w2, b2, 1e-6) - -# assert torch.allclose(y1, y2, atol=atol, rtol=rtol) - -# grad_output = torch.randn_like(y2) - -# y1.backward(grad_output, retain_graph=True) -# y2.backward(grad_output, retain_graph=True) - -# assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) -# assert torch.allclose(w1.grad, w2.grad, atol=atol, rtol=rtol) -# assert torch.allclose(b1.grad, b2.grad, atol=atol, rtol=rtol) + assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol) \ No newline at end of file From ba087fb79c2d2d97f1192980958f749940842509 Mon Sep 17 00:00:00 2001 From: pramodith Date: Fri, 1 Nov 2024 18:13:19 +0000 Subject: [PATCH 04/18] Fixed some edge cases --- src/liger_kernel/ops/group_norm.py | 41 ++++++++++++++-------------- test/transformers/test_group_norm.py | 34 ++++++++++++++--------- 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 8e4ea332d..1d1b26137 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -21,6 +21,7 @@ else: from triton.language.math import rsqrt +MAX_FUSED_SIZE = 65536 @triton.jit def _group_norm_forward_kernel( @@ -66,14 +67,15 @@ def _group_norm_forward_kernel( for i in range(0, hidden_size, BLOCK_SIZE): hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) mask = hidden_size_offsets < hidden_size - X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) + # We need to mask out of index with mean to ensure that the variance remains unaffected + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=mean) diff = X - mean variance += tl.sum(diff * diff) - variance = variance / (hidden_size) + variance = variance / hidden_size # 1/std rstd = rsqrt(variance + eps) - tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd) + tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, variance) # Normalize for i in range(0, hidden_size, BLOCK_SIZE): @@ -123,13 +125,13 @@ def _group_norm_backward_kernel( rstd = tl.load(RSTD_ptr + batch_idx * X_row_stride + group_idx * Mean_ptr_col_stride) W = tl.load(W_ptr + group_idx) - DW = 0.0 - DB = 0.0 + dW = 0.0 + dB = 0.0 for i in range(0, hidden_size, BLOCK_SIZE): hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) mask = hidden_size_offsets < hidden_size - X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=mean) UPSTREAM_grad = tl.load(UPSTREAM_ptr + hidden_size_offsets, mask=mask, other=0.0) """ Y = (X - mean) * rstd @@ -153,23 +155,24 @@ def _group_norm_backward_kernel( dg_dx = - (dmean_dx * rstd + mean * drstd_dx) dY_dx = df_dx + dg_dx - DX = W * UPSTREAM_grad * dY_dx + dX = W * UPSTREAM_grad * dY_dx - DW += tl.sum(UPSTREAM_grad * X) - DB += tl.sum(UPSTREAM_grad) + dW += tl.sum(UPSTREAM_grad * X) + dB += tl.sum(UPSTREAM_grad) - tl.store(DX_ptr + hidden_size_offsets, DX, mask=mask) + tl.store(DX_ptr + hidden_size_offsets, dX, mask=mask) - tl.store(DW_ptr, DW) - tl.store(DB_ptr, DB) + tl.store(DW_ptr, dW) + tl.store(DB_ptr, dB) def group_norm_forward(X, num_channels, num_groups, W, B, eps): shape = X.shape batch_size = shape[0] + print(X.stride(1)) # Reshape X so that the mean and std are computed across the groups X = X.view(batch_size, num_groups, -1) hidden_size = X.shape[-1] - BLOCK_SIZE, num_warps = calculate_settings(hidden_size) + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) Mean = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) RSTD = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) @@ -189,27 +192,24 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): RSTD.stride(1), hidden_size, eps, - BLOCK_SIZE=BLOCK_SIZE, - num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE ) Y = Y.view(*shape) affine_shape = [1] * len(shape) affine_shape[1] = num_channels Y = Y * W.view(affine_shape) + B.view(affine_shape) - - return Y, X, Mean, RSTD, BLOCK_SIZE, num_warps + return Y, X.view(*shape), Mean, RSTD, BLOCK_SIZE def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): shape = dY.shape - print(dY) batch_size = shape[0] hidden_size = dY.shape[-1] DX = torch.empty((batch_size, num_channels, hidden_size), dtype=X.dtype, device=X.device) DW = torch.empty((batch_size, num_channels), dtype=W.dtype, device=W.device) DB = torch.empty((batch_size, num_channels), dtype=B.dtype, device=B.device) - BLOCK_SIZE, num_warps = calculate_settings(hidden_size) + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) _group_norm_backward_kernel[(batch_size, num_channels)]( X, X.stride(0), @@ -228,7 +228,6 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): BLOCK_SIZE=BLOCK_SIZE ) print(DB) - print(DW) return DX, DW.sum(dim=0), DB.sum(dim=0) @@ -236,7 +235,7 @@ class LigerGroupNormFunction(torch.autograd.Function): @staticmethod @ensure_contiguous def forward(ctx, X, affine_scaling_weight, affine_shifting_bias, num_channels, num_groups, eps): - Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = group_norm_forward(X, num_channels, num_groups, affine_scaling_weight, affine_shifting_bias, eps) + Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(X, num_channels, num_groups, affine_scaling_weight, affine_shifting_bias, eps) ctx.num_channels = num_channels ctx.num_groups = num_groups ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD) diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 14b195b11..54f52da44 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -9,25 +9,27 @@ @pytest.mark.parametrize( "batch_size, num_channels, num_groups, hidden_size", [ - (1, 2, 1, 4), - (2, 4, 2, 128), - (16, 12, 3, 4096), + (2, 1, 1, 4), + # (2, 4, 2, 128), + # (16, 12, 3, 4096), ], ) @pytest.mark.parametrize( "dtype, atol, rtol", [ (torch.float32, 1e-5, 1e-5), - (torch.float16, 1e-3, 1e-3), + # (torch.float16, 1e-3, 1e-3), ], ) def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol): torch.manual_seed(0) - x = torch.randn( + _tensor = torch.randn( batch_size, num_channels, hidden_size, dtype=dtype, device="cuda", requires_grad=True ) + liger_x = _tensor.clone().detach().requires_grad_(True) + liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).cuda() torch_ln = torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6).to(dtype).cuda() @@ -35,16 +37,22 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty torch_ln.weight.copy_(liger_ln.weight) torch_ln.bias.copy_(liger_ln.bias) - liger_output = liger_ln(x,) - torch_output = torch_ln(x) + liger_output = liger_ln(liger_x,) + torch_output = torch_ln(_tensor) + assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol) - grad_output = torch.randn_like(x) + grad_output = torch.randn_like(_tensor) liger_output.backward(grad_output, retain_graph=True) torch_output.backward(grad_output, retain_graph=True) - - assert torch.allclose(x.grad, x.grad, atol=atol, rtol=rtol) - # assert torch.allclose( - # liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol - # ) + # print(liger_x.grad) + # print(_tensor.grad) + # assert torch.allclose(liger_x.grad, _tensor.grad, atol=atol, rtol=rtol) + # # assert torch.allclose( + # # liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol + # # ) + print(grad_output) + print(grad_output.sum()) + print(liger_ln.bias.grad) + print(torch_ln.bias.grad) assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol) \ No newline at end of file From 799cbe1b36291d5d969138be6a5b19edc5098c17 Mon Sep 17 00:00:00 2001 From: pramodith Date: Fri, 1 Nov 2024 19:04:12 +0000 Subject: [PATCH 05/18] Find the current group the right way. --- src/liger_kernel/ops/group_norm.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 1d1b26137..b264a09a1 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -75,7 +75,7 @@ def _group_norm_forward_kernel( variance = variance / hidden_size # 1/std rstd = rsqrt(variance + eps) - tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, variance) + tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd) # Normalize for i in range(0, hidden_size, BLOCK_SIZE): @@ -93,15 +93,17 @@ def _group_norm_backward_kernel( X_col_stride, # stride of each column in input W_ptr, # pointer to weights, shape (n_channels) Mean_ptr, # pointer to mean, shape (n_rows, n_groups) + Mean_ptr_row_stride, # stride of each column in mean Mean_ptr_col_stride, # stride of each column in mean RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups) DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size) DW_ptr, # pointer to weights grad, shape (n_channels) + DW_row_stride, # stride of each row in weights DW_col_stride, # stride of each column in weights DB_ptr, # pointer to bias grad, shape (n_channels) UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size) hidden_size: tl.constexpr, # hidden size - num_groups: tl.constexpr, # number of groups in group norm + channels_per_group: tl.constexpr, # number of groups in group norm BLOCK_SIZE: tl.constexpr, ): """ @@ -111,18 +113,20 @@ def _group_norm_backward_kernel( batch_idx = tl.program_id(0) channel_idx = tl.program_id(1) - group_idx = channel_idx // num_groups + group_idx = channel_idx // channels_per_group # X_col_stide will correspond to the number of groups X_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride DX_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride UPSTREAM_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride - DW_ptr += batch_idx * X_row_stride + channel_idx * DW_col_stride - DB_ptr += batch_idx * X_row_stride + channel_idx * DW_col_stride + # DW and DB have the same shape so have the same strides + DW_ptr += batch_idx * DW_row_stride + channel_idx * DW_col_stride + DB_ptr += batch_idx * DW_row_stride + channel_idx * DW_col_stride + # Mean and rstd are the same shape so have the same strides - mean = tl.load(Mean_ptr + batch_idx * X_row_stride + group_idx * Mean_ptr_col_stride) - rstd = tl.load(RSTD_ptr + batch_idx * X_row_stride + group_idx * Mean_ptr_col_stride) + mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) + rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) W = tl.load(W_ptr + group_idx) dW = 0.0 @@ -157,7 +161,8 @@ def _group_norm_backward_kernel( dX = W * UPSTREAM_grad * dY_dx - dW += tl.sum(UPSTREAM_grad * X) + c3 = c2 * rstd + dW += tl.sum(UPSTREAM_grad * c3) dB += tl.sum(UPSTREAM_grad) tl.store(DX_ptr + hidden_size_offsets, dX, mask=mask) @@ -206,6 +211,7 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): shape = dY.shape batch_size = shape[0] hidden_size = dY.shape[-1] + channels_per_group = num_channels // num_groups DX = torch.empty((batch_size, num_channels, hidden_size), dtype=X.dtype, device=X.device) DW = torch.empty((batch_size, num_channels), dtype=W.dtype, device=W.device) DB = torch.empty((batch_size, num_channels), dtype=B.dtype, device=B.device) @@ -216,15 +222,17 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): X.stride(1), W, Mean, + Mean.stride(0), Mean.stride(1), RSTD, DX, DW, + DW.stride(0), DW.stride(1), DB, dY, hidden_size, - num_groups, + channels_per_group, BLOCK_SIZE=BLOCK_SIZE ) print(DB) From 5c29ff034e46b7f05f3f37241460c0792975e2e0 Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Sun, 3 Nov 2024 19:42:24 +0000 Subject: [PATCH 06/18] More pointer bugs. --- src/liger_kernel/ops/group_norm.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index b264a09a1..ac281a97c 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -98,8 +98,6 @@ def _group_norm_backward_kernel( RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups) DX_ptr, # pointer to input grad, shape (n_rows, n_groups, hidden_size) DW_ptr, # pointer to weights grad, shape (n_channels) - DW_row_stride, # stride of each row in weights - DW_col_stride, # stride of each column in weights DB_ptr, # pointer to bias grad, shape (n_channels) UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size) hidden_size: tl.constexpr, # hidden size @@ -120,14 +118,14 @@ def _group_norm_backward_kernel( DX_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride UPSTREAM_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride - # DW and DB have the same shape so have the same strides - DW_ptr += batch_idx * DW_row_stride + channel_idx * DW_col_stride - DB_ptr += batch_idx * DW_row_stride + channel_idx * DW_col_stride + # We compute the gradients for W and B for the channel the thread is responsible for + DW_ptr += channel_idx + DB_ptr += channel_idx # Mean and rstd are the same shape so have the same strides mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) - W = tl.load(W_ptr + group_idx) + W = tl.load(W_ptr + channel_idx) dW = 0.0 dB = 0.0 From 06e2277fe57cfdb324e2d518c1cd3ab1cc82b40c Mon Sep 17 00:00:00 2001 From: pramodith Date: Mon, 4 Nov 2024 18:31:39 +0000 Subject: [PATCH 07/18] DB and DW work! --- src/liger_kernel/ops/group_norm.py | 106 ++++++++++++++++----------- test/transformers/test_group_norm.py | 34 +++++---- 2 files changed, 83 insertions(+), 57 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index ac281a97c..f7375b2f4 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -6,7 +6,6 @@ import triton.language as tl from liger_kernel.ops.utils import ( - calculate_settings, compare_version, ensure_contiguous, ) @@ -39,7 +38,7 @@ def _group_norm_forward_kernel( RSTD_col_stride, # stride of each column in rstd hidden_size, eps, - BLOCK_SIZE: tl.constexpr + BLOCK_SIZE: tl.constexpr, ): """ References: @@ -52,15 +51,14 @@ def _group_norm_forward_kernel( Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride # Compute mean - sum = 0.0 + s = 0.0 for i in range(0, hidden_size, BLOCK_SIZE): hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) mask = hidden_size_offsets < hidden_size X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) - sum += tl.sum(X) + s += tl.sum(X) - mean = sum / hidden_size - tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, mean) + m = s/hidden_size # Compute variance variance = 0.0 @@ -68,23 +66,27 @@ def _group_norm_forward_kernel( hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) mask = hidden_size_offsets < hidden_size # We need to mask out of index with mean to ensure that the variance remains unaffected - X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=mean) - diff = X - mean + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m) + diff = X - m variance += tl.sum(diff * diff) variance = variance / hidden_size # 1/std rstd = rsqrt(variance + eps) - tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd) + # Normalize for i in range(0, hidden_size, BLOCK_SIZE): hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) mask = hidden_size_offsets < hidden_size - X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) - Y = (X - mean) * rstd + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m) + Y = (X - m) * rstd tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask) + + tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m) + tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd) + @triton.jit def _group_norm_backward_kernel( @@ -103,6 +105,7 @@ def _group_norm_backward_kernel( hidden_size: tl.constexpr, # hidden size channels_per_group: tl.constexpr, # number of groups in group norm BLOCK_SIZE: tl.constexpr, + dtype: tl.constexpr, ): """ References: @@ -118,13 +121,9 @@ def _group_norm_backward_kernel( DX_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride UPSTREAM_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride - # We compute the gradients for W and B for the channel the thread is responsible for - DW_ptr += channel_idx - DB_ptr += channel_idx - # Mean and rstd are the same shape so have the same strides - mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) - rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) + mean = tl.load(Mean_ptr) + rstd = tl.load(RSTD_ptr) W = tl.load(W_ptr + channel_idx) dW = 0.0 @@ -147,40 +146,49 @@ def _group_norm_backward_kernel( """ # dh_dx = -0.5 * (rstd**3) * dvar_dx - c1 = 1 / hidden_size + # c1 = 1 / hidden_size c2 = X - mean - dmean_dx = c1 - dvar_dx = 2 * c2 * c1 - drstd_dx = -0.5 * (rstd*rstd*rstd) * dvar_dx + # c4 = tl.sum(c2) + # dmean_dx = c1 + # dvar_dx = c4 * c1 + # drstd_dx = - (rstd*rstd*rstd) * dvar_dx - df_dx = rstd + X * drstd_dx - dg_dx = - (dmean_dx * rstd + mean * drstd_dx) - dY_dx = df_dx + dg_dx + # df_dx = rstd + X * drstd_dx + # dg_dx = - (dmean_dx * rstd + mean * drstd_dx) + # dY_dx = df_dx + dg_dx - dX = W * UPSTREAM_grad * dY_dx + # dX = W * UPSTREAM_grad * dY_dx - c3 = c2 * rstd - dW += tl.sum(UPSTREAM_grad * c3) + norm = c2 * rstd + dW += tl.sum(UPSTREAM_grad * norm) dB += tl.sum(UPSTREAM_grad) - tl.store(DX_ptr + hidden_size_offsets, dX, mask=mask) + dnorm = UPSTREAM_grad * W + dx = dnorm - dnorm. + + # tl.store(DX_ptr + hidden_size_offsets, dX, mask=mask) - tl.store(DW_ptr, dW) - tl.store(DB_ptr, dB) + # Need to ensure additions to the same channel are atomic + tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype)) + tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype)) def group_norm_forward(X, num_channels, num_groups, W, B, eps): shape = X.shape batch_size = shape[0] - print(X.stride(1)) # Reshape X so that the mean and std are computed across the groups - X = X.view(batch_size, num_groups, -1) + X = X.view(batch_size, num_groups, -1).contiguous() hidden_size = X.shape[-1] + print(f"Mean is {X.view(-1).mean()}") + print(f"RSTD is {1/torch.sqrt(torch.var(X.view(-1), unbiased=False) + 1e-6)}") BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) - Mean = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) - RSTD = torch.empty((batch_size, num_groups), dtype=X.dtype, device=X.device) + Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) + # print(f"Init mean is {Mean}") + RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) + # print(f"Init RSTD is {RSTD}") + misc = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) - _group_norm_forward_kernel[(batch_size, num_channels)]( + _group_norm_forward_kernel[(batch_size, num_groups)]( Y, Y.stride(0), Y.stride(1), @@ -195,9 +203,13 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): RSTD.stride(1), hidden_size, eps, - BLOCK_SIZE=BLOCK_SIZE + BLOCK_SIZE=BLOCK_SIZE, + misc_ptr=misc ) - + print(X-Y) + print(f"Misc is {misc}") + print(f"After Init mean {Mean}") + print(f"After Init rstd {RSTD}") Y = Y.view(*shape) affine_shape = [1] * len(shape) affine_shape[1] = num_channels @@ -206,13 +218,20 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): + # print(f"Sum of upstream is : {dY.sum()}") shape = dY.shape batch_size = shape[0] hidden_size = dY.shape[-1] channels_per_group = num_channels // num_groups DX = torch.empty((batch_size, num_channels, hidden_size), dtype=X.dtype, device=X.device) - DW = torch.empty((batch_size, num_channels), dtype=W.dtype, device=W.device) - DB = torch.empty((batch_size, num_channels), dtype=B.dtype, device=B.device) + DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device) + DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device) + triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16 + # print(Mean) + # print(RSTD) + # ans = (X - Mean[0]) * RSTD[0] + # ans = ans * dY + # print(f"Torch ans is {ans.sum(dim=-1)}") BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) _group_norm_backward_kernel[(batch_size, num_channels)]( X, @@ -225,16 +244,15 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): RSTD, DX, DW, - DW.stride(0), - DW.stride(1), DB, dY, hidden_size, channels_per_group, - BLOCK_SIZE=BLOCK_SIZE + BLOCK_SIZE=BLOCK_SIZE, + dtype=triton_dtype ) - print(DB) - return DX, DW.sum(dim=0), DB.sum(dim=0) + print(DW) + return DX, DW, DB class LigerGroupNormFunction(torch.autograd.Function): diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 54f52da44..5607acce1 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -1,4 +1,5 @@ import pytest +import random import torch from liger_kernel.ops.group_norm import LigerGroupNormFunction @@ -6,12 +7,18 @@ from liger_kernel.transformers.group_norm import LigerGroupNorm +random_batch_size = random.randint(1, 16) +random_num_groups = random.randint(1, 32) +random_num_channels = random_num_groups * random.randint(1, 16) +random_hidden_size = random.randint(1, 32) + @pytest.mark.parametrize( "batch_size, num_channels, num_groups, hidden_size", [ - (2, 1, 1, 4), - # (2, 4, 2, 128), + (1, 3, 1, 16), + # (2, 4, 2, 31), # (16, 12, 3, 4096), + # (random_batch_size, random_num_channels, random_num_groups, random_hidden_size), ], ) @pytest.mark.parametrize( @@ -25,10 +32,11 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty torch.manual_seed(0) _tensor = torch.randn( - batch_size, num_channels, hidden_size, dtype=dtype, device="cuda", requires_grad=True + batch_size, num_channels, hidden_size, dtype=dtype, device="cuda" ) liger_x = _tensor.clone().detach().requires_grad_(True) + torch_x = _tensor.clone().detach().requires_grad_(True) liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).cuda() torch_ln = torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6).to(dtype).cuda() @@ -38,21 +46,21 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty torch_ln.bias.copy_(liger_ln.bias) liger_output = liger_ln(liger_x,) - torch_output = torch_ln(_tensor) + torch_output = torch_ln(torch_x) assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol) - grad_output = torch.randn_like(_tensor) + grad_output = torch.randn_like(torch_x) liger_output.backward(grad_output, retain_graph=True) torch_output.backward(grad_output, retain_graph=True) # print(liger_x.grad) # print(_tensor.grad) # assert torch.allclose(liger_x.grad, _tensor.grad, atol=atol, rtol=rtol) - # # assert torch.allclose( - # # liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol - # # ) - print(grad_output) - print(grad_output.sum()) - print(liger_ln.bias.grad) - print(torch_ln.bias.grad) - assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol) \ No newline at end of file + # print(f"Upstream Gradient: {grad_output}") + print(liger_x.shape) + print(liger_ln.weight.grad) + print(torch_ln.weight.grad) + assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol) + assert torch.allclose( + liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol + ) From bfe29bf4e7f8a5d1e02a00cdf4439d91f0e059cd Mon Sep 17 00:00:00 2001 From: pramodith Date: Mon, 4 Nov 2024 19:36:57 +0000 Subject: [PATCH 08/18] cuanges --- src/liger_kernel/ops/group_norm.py | 37 ++++++++++++++++------------ test/transformers/test_group_norm.py | 25 +++++++++---------- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index f7375b2f4..3f61e105d 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -104,6 +104,7 @@ def _group_norm_backward_kernel( UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size) hidden_size: tl.constexpr, # hidden size channels_per_group: tl.constexpr, # number of groups in group norm + num_groups: tl.constexpr, BLOCK_SIZE: tl.constexpr, dtype: tl.constexpr, ): @@ -122,8 +123,8 @@ def _group_norm_backward_kernel( UPSTREAM_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride # Mean and rstd are the same shape so have the same strides - mean = tl.load(Mean_ptr) - rstd = tl.load(RSTD_ptr) + mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) + rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) W = tl.load(W_ptr + channel_idx) dW = 0.0 @@ -146,7 +147,7 @@ def _group_norm_backward_kernel( """ # dh_dx = -0.5 * (rstd**3) * dvar_dx - # c1 = 1 / hidden_size + c1 = 1 / hidden_size c2 = X - mean # c4 = tl.sum(c2) # dmean_dx = c1 @@ -159,12 +160,16 @@ def _group_norm_backward_kernel( # dX = W * UPSTREAM_grad * dY_dx - norm = c2 * rstd - dW += tl.sum(UPSTREAM_grad * norm) + c3 = c2 * rstd + dW += tl.sum(UPSTREAM_grad * c3) dB += tl.sum(UPSTREAM_grad) - dnorm = UPSTREAM_grad * W - dx = dnorm - dnorm. + x_hat = (X - mean) * rstd + wdy = W * UPSTREAM_grad + c1 = tl.sum(x_hat * wdy) / (hidden_size * num_groups) + c2 = tl.sum(wdy, axis=0) / (hidden_size * num_groups) + dx = (wdy - (x_hat * c1 + c2)) * rstd + tl.store(DX_ptr + hidden_size_offsets, dx.to(dtype), mask=mask) # tl.store(DX_ptr + hidden_size_offsets, dX, mask=mask) @@ -178,15 +183,14 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): # Reshape X so that the mean and std are computed across the groups X = X.view(batch_size, num_groups, -1).contiguous() hidden_size = X.shape[-1] - print(f"Mean is {X.view(-1).mean()}") - print(f"RSTD is {1/torch.sqrt(torch.var(X.view(-1), unbiased=False) + 1e-6)}") + # print(f"Mean is {X.view(-1).mean()}") + # print(f"RSTD is {1/torch.sqrt(torch.var(X.view(-1), unbiased=False) + 1e-6)}") BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) # print(f"Init mean is {Mean}") RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) # print(f"Init RSTD is {RSTD}") - misc = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) _group_norm_forward_kernel[(batch_size, num_groups)]( Y, @@ -204,12 +208,9 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): hidden_size, eps, BLOCK_SIZE=BLOCK_SIZE, - misc_ptr=misc ) - print(X-Y) - print(f"Misc is {misc}") - print(f"After Init mean {Mean}") - print(f"After Init rstd {RSTD}") + # print(f"After Init mean {Mean}") + # print(f"After Init rstd {RSTD}") Y = Y.view(*shape) affine_shape = [1] * len(shape) affine_shape[1] = num_channels @@ -248,10 +249,14 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): dY, hidden_size, channels_per_group, + num_groups, BLOCK_SIZE=BLOCK_SIZE, dtype=triton_dtype ) - print(DW) + print(Mean) + a = (X.view(batch_size, num_groups, -1) - Mean.view(batch_size, num_groups, -1)) * RSTD.view(batch_size, num_groups, -1) + b = a * dY.view(batch_size, num_groups, -1) + print(f"Pure torch output: {b.view(batch_size, num_channels, -1).sum(-1)}") return DX, DW, DB diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 5607acce1..8013d2477 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -14,18 +14,18 @@ @pytest.mark.parametrize( "batch_size, num_channels, num_groups, hidden_size", - [ - (1, 3, 1, 16), - # (2, 4, 2, 31), - # (16, 12, 3, 4096), - # (random_batch_size, random_num_channels, random_num_groups, random_hidden_size), + [ + (1, 3, 1, 15), + (1, 4, 2, 4), + (16, 12, 3, 4096), + (random_batch_size, random_num_channels, random_num_groups, random_hidden_size), ], ) @pytest.mark.parametrize( "dtype, atol, rtol", [ (torch.float32, 1e-5, 1e-5), - # (torch.float16, 1e-3, 1e-3), + #(torch.float16, 1e-3, 1e-3), ], ) def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol): @@ -49,18 +49,17 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty torch_output = torch_ln(torch_x) assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol) - grad_output = torch.randn_like(torch_x) liger_output.backward(grad_output, retain_graph=True) torch_output.backward(grad_output, retain_graph=True) # print(liger_x.grad) - # print(_tensor.grad) - # assert torch.allclose(liger_x.grad, _tensor.grad, atol=atol, rtol=rtol) + # print(torch_x.grad) + # assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) # print(f"Upstream Gradient: {grad_output}") - print(liger_x.shape) - print(liger_ln.weight.grad) - print(torch_ln.weight.grad) + # print(liger_x.shape) + print(f"Liger: grad {liger_ln.weight.grad}") + print(f"Torch: grad {torch_ln.weight.grad}") assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol) assert torch.allclose( liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol - ) + ) \ No newline at end of file From a949c89a40071206ecb354a1b87165c8f23af4f2 Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Tue, 5 Nov 2024 00:09:16 +0000 Subject: [PATCH 09/18] progress --- src/liger_kernel/ops/group_norm.py | 16 +++++++--------- test/transformers/test_group_norm.py | 23 ++++++++++++----------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 3f61e105d..762486a55 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -104,13 +104,16 @@ def _group_norm_backward_kernel( UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size) hidden_size: tl.constexpr, # hidden size channels_per_group: tl.constexpr, # number of groups in group norm - num_groups: tl.constexpr, BLOCK_SIZE: tl.constexpr, dtype: tl.constexpr, ): """ References: https://nn.labml.ai/normalization/group_norm/index.html + The backprop equations are the same for group_norm and layer_norm + the only difference here is that we load the W, Mean, Rstd corresponding to the + group we're computing gradients for and the mean and rstd are computed over n-channels + so the total number of elements we compute the mean over is num_channels_per_group * hidden_size """ batch_idx = tl.program_id(0) channel_idx = tl.program_id(1) @@ -133,7 +136,7 @@ def _group_norm_backward_kernel( for i in range(0, hidden_size, BLOCK_SIZE): hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) mask = hidden_size_offsets < hidden_size - X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=mean) + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) UPSTREAM_grad = tl.load(UPSTREAM_ptr + hidden_size_offsets, mask=mask, other=0.0) """ Y = (X - mean) * rstd @@ -166,8 +169,8 @@ def _group_norm_backward_kernel( x_hat = (X - mean) * rstd wdy = W * UPSTREAM_grad - c1 = tl.sum(x_hat * wdy) / (hidden_size * num_groups) - c2 = tl.sum(wdy, axis=0) / (hidden_size * num_groups) + c1 = tl.sum(x_hat * wdy) / (hidden_size * channels_per_group) + c2 = tl.sum(wdy) / (hidden_size * channels_per_group) dx = (wdy - (x_hat * c1 + c2)) * rstd tl.store(DX_ptr + hidden_size_offsets, dx.to(dtype), mask=mask) @@ -249,14 +252,9 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): dY, hidden_size, channels_per_group, - num_groups, BLOCK_SIZE=BLOCK_SIZE, dtype=triton_dtype ) - print(Mean) - a = (X.view(batch_size, num_groups, -1) - Mean.view(batch_size, num_groups, -1)) * RSTD.view(batch_size, num_groups, -1) - b = a * dY.view(batch_size, num_groups, -1) - print(f"Pure torch output: {b.view(batch_size, num_channels, -1).sum(-1)}") return DX, DW, DB diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 8013d2477..06eb5bd14 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -15,10 +15,10 @@ @pytest.mark.parametrize( "batch_size, num_channels, num_groups, hidden_size", [ - (1, 3, 1, 15), - (1, 4, 2, 4), - (16, 12, 3, 4096), - (random_batch_size, random_num_channels, random_num_groups, random_hidden_size), + (1, 2, 1, 3), + # (1, 4, 2, 4), + # (16, 12, 3, 4096), + # (random_batch_size, random_num_channels, random_num_groups, random_hidden_size), ], ) @pytest.mark.parametrize( @@ -52,14 +52,15 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty grad_output = torch.randn_like(torch_x) liger_output.backward(grad_output, retain_graph=True) torch_output.backward(grad_output, retain_graph=True) - # print(liger_x.grad) - # print(torch_x.grad) - # assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) - # print(f"Upstream Gradient: {grad_output}") + print(f"Input grad liger: {liger_x.grad}") + + print(f"Torch grad :{torch_x.grad}") + assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) + print(f"Upstream Gradient: {grad_output}") # print(liger_x.shape) print(f"Liger: grad {liger_ln.weight.grad}") print(f"Torch: grad {torch_ln.weight.grad}") - assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol) - assert torch.allclose( + assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol), "Bias grads different" + assert torch.allclose( liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol - ) \ No newline at end of file + ), "Weight grads different" \ No newline at end of file From fd7009edbd2466a80a7376956b134767e5894873 Mon Sep 17 00:00:00 2001 From: pramodith Date: Tue, 5 Nov 2024 14:35:30 +0000 Subject: [PATCH 10/18] Fp32 all tests pass --- benchmark/scripts/benchmark_group_norm.py | 132 ++++++++++++++++++++++ src/liger_kernel/ops/group_norm.py | 110 +++++++++--------- test/transformers/test_group_norm.py | 23 ++-- 3 files changed, 191 insertions(+), 74 deletions(-) create mode 100644 benchmark/scripts/benchmark_group_norm.py diff --git a/benchmark/scripts/benchmark_group_norm.py b/benchmark/scripts/benchmark_group_norm.py new file mode 100644 index 000000000..7d27ca8da --- /dev/null +++ b/benchmark/scripts/benchmark_group_norm.py @@ -0,0 +1,132 @@ +import torch +import triton +from utils import ( + QUANTILES, + SingleBenchmarkRunInput, + SingleBenchmarkRunOutput, + _test_memory, + parse_benchmark_script_args, + run_benchmarks, +) + +from liger_kernel.transformers.group_norm import LigerGroupNorm + + +def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + C = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + H = extra_benchmark_config["H"] + channels_per_group = extra_benchmark_config["channels_per_group"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, C, H) + triton_ln = LigerGroupNorm(num_channels=C, num_groups=C//channels_per_group, eps=eps).to("cuda") + torch_ln = torch.nn.GroupNorm(num_groups=C//channels_per_group, num_channels=C, eps=eps).to("cuda") + + x = torch.randn(x_shape, dtype=dtype, device="cuda") + dy = torch.randn_like(x) + x.requires_grad_(True) + + def y_fwd(): + if provider == "liger": + return triton_ln(x) + if provider == "huggingface": + return torch_ln(x) + + if mode == "forward": + ms_50, ms_20, ms_80 = triton.testing.do_bench( + y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500 + ) + elif mode == "backward": + y = y_fwd() + ms_50, ms_20, ms_80 = triton.testing.do_bench( + lambda: y.backward(dy, retain_graph=True), + quantiles=QUANTILES, + grad_to_none=[x], + rep=500, + ) + elif mode == "full": + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + ms_50, ms_20, ms_80 = triton.testing.do_bench( + full, quantiles=QUANTILES, grad_to_none=[x], rep=500 + ) + + return SingleBenchmarkRunOutput( + y_20=ms_20, + y_50=ms_50, + y_80=ms_80, + ) + + +def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: + C = input.x + provider = input.kernel_provider + mode = input.kernel_operation_mode + extra_benchmark_config = input.extra_benchmark_config + M = extra_benchmark_config["M"] + H = extra_benchmark_config["H"] + channels_per_group = extra_benchmark_config["channels_per_group"] + eps = extra_benchmark_config["eps"] + dtype = extra_benchmark_config["dtype"] + + x_shape = (M, C, H) + triton_ln = LigerGroupNorm(num_channels=C, num_groups=C//channels_per_group, eps=eps).to("cuda") + torch_ln = torch.nn.GroupNorm(num_groups=C//channels_per_group, num_channels=C, eps=eps).to("cuda") + + x = torch.randn(x_shape, dtype=dtype, device="cuda") + dy = torch.randn_like(x) + x.requires_grad_(True) + + def y_fwd(): + if provider == "liger": + return triton_ln(x) + if provider == "huggingface": + return torch_ln(x) + + def full(): + y = y_fwd() + y.backward(dy, retain_graph=True) + + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) + return SingleBenchmarkRunOutput( + y_20=mem_20, + y_50=mem_50, + y_80=mem_80, + ) + + +if __name__ == "__main__": + args = parse_benchmark_script_args() + + common_configs = { + "kernel_name": "group_norm", + "x_name": "C", + "x_label": "num_channels", + "x_values": [2**i for i in range(10, 15)], + "kernel_providers": ["liger", "huggingface"], + "extra_benchmark_configs": [{"M": 4096, "H": 512, "channels_per_group": 4, "dtype": torch.float32, "eps": 1e-6}], + "overwrite": args.overwrite, + } + + run_benchmarks( + bench_test_fn=bench_speed_group_norm, + kernel_operation_modes=["forward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs + ) + run_benchmarks( + bench_test_fn=bench_memory_group_norm, + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs + ) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 762486a55..98a3adaa2 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -116,69 +116,65 @@ def _group_norm_backward_kernel( so the total number of elements we compute the mean over is num_channels_per_group * hidden_size """ batch_idx = tl.program_id(0) - channel_idx = tl.program_id(1) - - group_idx = channel_idx // channels_per_group + group_idx = tl.program_id(1) - # X_col_stide will correspond to the number of groups - X_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride - DX_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride - UPSTREAM_ptr += batch_idx * X_row_stride + channel_idx * X_col_stride + # Move the pointers to the correct batch + X_ptr += batch_idx * X_row_stride + DX_ptr += batch_idx * X_row_stride + UPSTREAM_ptr += batch_idx * X_row_stride # Mean and rstd are the same shape so have the same strides mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) - W = tl.load(W_ptr + channel_idx) - - dW = 0.0 - dB = 0.0 - for i in range(0, hidden_size, BLOCK_SIZE): - hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) - mask = hidden_size_offsets < hidden_size - X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) - UPSTREAM_grad = tl.load(UPSTREAM_ptr + hidden_size_offsets, mask=mask, other=0.0) - """ - Y = (X - mean) * rstd + c1 = 0.0 + c2 = 0.0 - h(x) = rstd = 1/(sqrt(var + eps)) - f(x) = x * h(x) = X * rstd - g(x) = - mean * h(x) = - mean * rstd + dW = tl.zeros((1), dtype=dtype) + dB = tl.zeros((1), dtype=dtype) - Y = f(x) + g(x) - dy_dx = df_dx + dg_dx - """ - # dh_dx = -0.5 * (rstd**3) * dvar_dx - c1 = 1 / hidden_size - c2 = X - mean - # c4 = tl.sum(c2) - # dmean_dx = c1 - # dvar_dx = c4 * c1 - # drstd_dx = - (rstd*rstd*rstd) * dvar_dx + # We need to compute the sum terms of the backprop equations across all channels in the group + for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): + # Move the pointers to the correct channel - # df_dx = rstd + X * drstd_dx - # dg_dx = - (dmean_dx * rstd + mean * drstd_dx) - # dY_dx = df_dx + dg_dx + W = tl.load(W_ptr + channel_idx) + for i in range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + mask = hidden_size_offsets < hidden_size + X = tl.load(X_ptr + channel_idx * X_col_stride + hidden_size_offsets, mask=mask, other=0.0) + UPSTREAM_grad = tl.load(UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, mask=mask, other=0.0) + + x_hat = (X - mean) * rstd + dW = tl.sum(UPSTREAM_grad * x_hat) + dB = tl.sum(UPSTREAM_grad) - # dX = W * UPSTREAM_grad * dY_dx - - c3 = c2 * rstd - dW += tl.sum(UPSTREAM_grad * c3) - dB += tl.sum(UPSTREAM_grad) + wdy = W * UPSTREAM_grad + c1 += tl.sum(x_hat * wdy) + c2 += tl.sum(wdy) + + # Need to ensure additions to the same channel are atomic + tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype)) + tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype)) + + c1 = c1/(hidden_size * channels_per_group) + c2 = c2/(hidden_size * channels_per_group) + + for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): + # Move the pointers to the correct channel + W = tl.load(W_ptr + channel_idx) + for i in range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + mask = hidden_size_offsets < hidden_size + X = tl.load(X_ptr + channel_idx * X_col_stride + hidden_size_offsets, mask=mask, other=0.0) + UPSTREAM_grad = tl.load(UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, mask=mask, other=0.0) + + x_hat = (X - mean) * rstd + wdy = W * UPSTREAM_grad + dx = (wdy - (x_hat * c1 + c2)) * rstd + tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask) - x_hat = (X - mean) * rstd - wdy = W * UPSTREAM_grad - c1 = tl.sum(x_hat * wdy) / (hidden_size * channels_per_group) - c2 = tl.sum(wdy) / (hidden_size * channels_per_group) - dx = (wdy - (x_hat * c1 + c2)) * rstd - tl.store(DX_ptr + hidden_size_offsets, dx.to(dtype), mask=mask) - # tl.store(DX_ptr + hidden_size_offsets, dX, mask=mask) - - # Need to ensure additions to the same channel are atomic - tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype)) - tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype)) def group_norm_forward(X, num_channels, num_groups, W, B, eps): shape = X.shape @@ -227,17 +223,14 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): batch_size = shape[0] hidden_size = dY.shape[-1] channels_per_group = num_channels // num_groups - DX = torch.empty((batch_size, num_channels, hidden_size), dtype=X.dtype, device=X.device) + dY = dY.view(batch_size, num_groups, -1) + DX = torch.empty((batch_size, num_groups, hidden_size * channels_per_group), dtype=X.dtype, device=X.device) DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device) DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device) triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16 - # print(Mean) - # print(RSTD) - # ans = (X - Mean[0]) * RSTD[0] - # ans = ans * dY - # print(f"Torch ans is {ans.sum(dim=-1)}") + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) - _group_norm_backward_kernel[(batch_size, num_channels)]( + _group_norm_backward_kernel[(batch_size, num_groups)]( X, X.stride(0), X.stride(1), @@ -255,7 +248,8 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): BLOCK_SIZE=BLOCK_SIZE, dtype=triton_dtype ) - return DX, DW, DB + + return DX.view(*shape), DW, DB class LigerGroupNormFunction(torch.autograd.Function): diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 06eb5bd14..96bfd88b2 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -2,30 +2,27 @@ import random import torch -from liger_kernel.ops.group_norm import LigerGroupNormFunction -from liger_kernel.transformers.functional import liger_group_norm from liger_kernel.transformers.group_norm import LigerGroupNorm random_batch_size = random.randint(1, 16) random_num_groups = random.randint(1, 32) random_num_channels = random_num_groups * random.randint(1, 16) -random_hidden_size = random.randint(1, 32) +random_hidden_size = random.randint(1, 8192) @pytest.mark.parametrize( "batch_size, num_channels, num_groups, hidden_size", [ (1, 2, 1, 3), - # (1, 4, 2, 4), - # (16, 12, 3, 4096), - # (random_batch_size, random_num_channels, random_num_groups, random_hidden_size), + (1, 4, 2, 4), + (16, 12, 3, 4096), + (random_batch_size, random_num_channels, random_num_groups, random_hidden_size), ], ) @pytest.mark.parametrize( "dtype, atol, rtol", [ - (torch.float32, 1e-5, 1e-5), - #(torch.float16, 1e-3, 1e-3), + (torch.float32, 1e-4, 1e-4), ], ) def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol): @@ -52,15 +49,9 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty grad_output = torch.randn_like(torch_x) liger_output.backward(grad_output, retain_graph=True) torch_output.backward(grad_output, retain_graph=True) - print(f"Input grad liger: {liger_x.grad}") - - print(f"Torch grad :{torch_x.grad}") assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) - print(f"Upstream Gradient: {grad_output}") - # print(liger_x.shape) - print(f"Liger: grad {liger_ln.weight.grad}") - print(f"Torch: grad {torch_ln.weight.grad}") assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol), "Bias grads different" + close_mask = torch.isclose(liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol) assert torch.allclose( liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol - ), "Weight grads different" \ No newline at end of file + ), "Weight grads different" From 44a3f78ad92db6212e8523ffd5df5660c082716e Mon Sep 17 00:00:00 2001 From: pramodith Date: Tue, 5 Nov 2024 14:36:01 +0000 Subject: [PATCH 11/18] Remove line --- test/transformers/test_group_norm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 96bfd88b2..008d4d11d 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -51,7 +51,6 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty torch_output.backward(grad_output, retain_graph=True) assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol), "Bias grads different" - close_mask = torch.isclose(liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol) assert torch.allclose( liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol ), "Weight grads different" From 4728e13cbfaeab7e41d77f411ed260f561c43a66 Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:54:00 +0000 Subject: [PATCH 12/18] V1 --- benchmark/data/all_benchmark_data.csv | 84 +++++++++++++++++++++++ benchmark/scripts/benchmark_group_norm.py | 8 +-- src/liger_kernel/ops/group_norm.py | 54 +++++++-------- 3 files changed, 115 insertions(+), 31 deletions(-) diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index 32c8d01ab..d1b092a2b 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -505,3 +505,87 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859 fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,32,0.035840000957250595,0.03174399957060814,0.04505600035190582,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05222399905323982,0.053247999399900436,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,128,0.0870399996638298,0.0870399996638298,0.08806400001049042,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,256,0.1443839967250824,0.1443839967250824,0.1454080045223236,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,512,0.26521599292755127,0.26419198513031006,0.2662400007247925,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,1024,0.5140479803085327,0.5120000243186951,0.5160959959030151,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,2048,1.006592035293579,1.0035200119018555,1.0096640586853027,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,32,0.04198399931192398,0.04095999896526337,0.04198399931192398,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,64,0.06963200122117996,0.06860800087451935,0.06963200122117996,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,128,0.12492799758911133,0.12492799758911133,0.12492799758911133,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,256,0.2314240038394928,0.2303999960422516,0.2314240038394928,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,512,0.4505600035190582,0.4505600035190582,0.45260798931121826,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,1024,0.9011200070381165,0.8980479836463928,0.9031680226325989,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,2048,1.7950719594955444,1.7920000553131104,1.7960959672927856,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,32,0.28569599986076355,0.2815999984741211,0.29388800263404846,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,64,0.19763199985027313,0.19046400487422943,0.3768320083618164,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,128,0.2099200040102005,0.20787200331687927,0.21094399690628052,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,256,0.38092800974845886,0.37990400195121765,0.3829759955406189,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,512,0.7219200134277344,0.719871997833252,0.7229440212249756,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,1024,1.4049279689788818,1.4018559455871582,1.4090240001678467,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,2048,2.7458558082580566,2.743295907974243,2.748415946960449,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,32,0.12185599654912949,0.11878400295972824,0.13844487071037292,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,64,0.2099200040102005,0.2088959962129593,0.21094399690628052,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,128,0.33792001008987427,0.33689600229263306,0.33792001008987427,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,256,0.5908480286598206,0.5908480286598206,0.591871976852417,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,512,1.1110399961471558,1.106943964958191,1.1141120195388794,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,1024,2.160640001296997,2.1585919857025146,2.1780478954315186,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,2048,4.2690558433532715,4.2485761642456055,4.274585723876953,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,32,0.07884799689054489,0.07475200295448303,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,64,0.08499199897050858,0.08294399827718735,0.2908160090446472,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,128,0.13209599256515503,0.131071999669075,0.13312000036239624,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,256,0.24166400730609894,0.24063999950885773,0.24268800020217896,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,512,0.4556800127029419,0.4546560049057007,0.4567039906978607,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,1024,0.8919039964675903,0.8908799886703491,0.8939520120620728,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,2048,1.7643519639968872,1.7623039484024048,1.7663999795913696,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,32,0.08499199897050858,0.08396799862384796,0.08601599931716919,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,64,0.14643199741840363,0.14643199741840363,0.14745600521564484,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,128,0.2170879989862442,0.21606400609016418,0.2181120067834854,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,256,0.3614720106124878,0.3604480028152466,0.3624959886074066,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,512,0.652288019657135,0.6512640118598938,0.6563839912414551,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,1024,1.2584960460662842,1.2533760070800781,1.2615679502487183,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,2048,2.4688639640808105,2.465996742248535,2.4829952716827393,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 diff --git a/benchmark/scripts/benchmark_group_norm.py b/benchmark/scripts/benchmark_group_norm.py index 7d27ca8da..38946eb79 100644 --- a/benchmark/scripts/benchmark_group_norm.py +++ b/benchmark/scripts/benchmark_group_norm.py @@ -110,22 +110,22 @@ def full(): "kernel_name": "group_norm", "x_name": "C", "x_label": "num_channels", - "x_values": [2**i for i in range(10, 15)], + "x_values": [2**i for i in range(5, 12)], "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [{"M": 4096, "H": 512, "channels_per_group": 4, "dtype": torch.float32, "eps": 1e-6}], + "extra_benchmark_configs": [{"M": 128, "H": 512, "channels_per_group": 4, "dtype": torch.float32, "eps": 1e-6}], "overwrite": args.overwrite, } run_benchmarks( bench_test_fn=bench_speed_group_norm, - kernel_operation_modes=["forward", "full"], + kernel_operation_modes=["forward", "full", "backward"], metric_name="speed", metric_unit="ms", **common_configs ) run_benchmarks( bench_test_fn=bench_memory_group_norm, - kernel_operation_modes=["full"], + kernel_operation_modes=["full", "forward", "backward"], metric_name="memory", metric_unit="MB", **common_configs diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 98a3adaa2..4f0faac68 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -36,7 +36,10 @@ def _group_norm_forward_kernel( RSTD_ptr, # pointer to rstd, shape (n_rows, n_groups) RSTD_row_stride, # stride of each row in rstd RSTD_col_stride, # stride of each column in rstd + W_ptr, # pointer to W + B_ptr, # pointer to B hidden_size, + channels_per_group, eps, BLOCK_SIZE: tl.constexpr, ): @@ -74,15 +77,20 @@ def _group_norm_forward_kernel( # 1/std rstd = rsqrt(variance + eps) - # Normalize - for i in range(0, hidden_size, BLOCK_SIZE): - hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) - mask = hidden_size_offsets < hidden_size - X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m) - Y = (X - m) * rstd - tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask) - + hidden_size_per_channel = hidden_size//channels_per_group + for channel_idx in range(group_idx*channels_per_group, (group_idx+1)*channels_per_group): + W = tl.load(W_ptr + channel_idx) + B = tl.load(B_ptr + channel_idx) + for i in range(0, hidden_size_per_channel, BLOCK_SIZE): + hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + mask = hidden_size_offsets < hidden_size_per_channel + X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m) + Y = (X - m) * rstd * W + B + tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask) + + X_ptr += hidden_size_per_channel + Y_ptr += hidden_size_per_channel tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m) tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd) @@ -130,14 +138,12 @@ def _group_norm_backward_kernel( c1 = 0.0 c2 = 0.0 - dW = tl.zeros((1), dtype=dtype) - dB = tl.zeros((1), dtype=dtype) - - # We need to compute the sum terms of the backprop equations across all channels in the group for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): # Move the pointers to the correct channel + dW = 0.0 + dB = 0.0 W = tl.load(W_ptr + channel_idx) for i in range(0, hidden_size, BLOCK_SIZE): hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) @@ -146,8 +152,8 @@ def _group_norm_backward_kernel( UPSTREAM_grad = tl.load(UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, mask=mask, other=0.0) x_hat = (X - mean) * rstd - dW = tl.sum(UPSTREAM_grad * x_hat) - dB = tl.sum(UPSTREAM_grad) + dW += tl.sum(UPSTREAM_grad * x_hat) + dB += tl.sum(UPSTREAM_grad) wdy = W * UPSTREAM_grad c1 += tl.sum(x_hat * wdy) @@ -157,8 +163,9 @@ def _group_norm_backward_kernel( tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype)) tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype)) - c1 = c1/(hidden_size * channels_per_group) - c2 = c2/(hidden_size * channels_per_group) + N = hidden_size * channels_per_group + c1 = c1/N + c2 = c2/N for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): # Move the pointers to the correct channel @@ -175,21 +182,17 @@ def _group_norm_backward_kernel( tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask) - def group_norm_forward(X, num_channels, num_groups, W, B, eps): shape = X.shape batch_size = shape[0] + channels_per_group = num_channels//num_groups # Reshape X so that the mean and std are computed across the groups X = X.view(batch_size, num_groups, -1).contiguous() hidden_size = X.shape[-1] - # print(f"Mean is {X.view(-1).mean()}") - # print(f"RSTD is {1/torch.sqrt(torch.var(X.view(-1), unbiased=False) + 1e-6)}") BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) - # print(f"Init mean is {Mean}") RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) - # print(f"Init RSTD is {RSTD}") _group_norm_forward_kernel[(batch_size, num_groups)]( Y, @@ -204,21 +207,18 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): RSTD, RSTD.stride(0), RSTD.stride(1), + W, + B, hidden_size, + channels_per_group, eps, BLOCK_SIZE=BLOCK_SIZE, ) - # print(f"After Init mean {Mean}") - # print(f"After Init rstd {RSTD}") Y = Y.view(*shape) - affine_shape = [1] * len(shape) - affine_shape[1] = num_channels - Y = Y * W.view(affine_shape) + B.view(affine_shape) return Y, X.view(*shape), Mean, RSTD, BLOCK_SIZE def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): - # print(f"Sum of upstream is : {dY.sum()}") shape = dY.shape batch_size = shape[0] hidden_size = dY.shape[-1] From 8196d259eb23399156960595e04ed479bfd79940 Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Tue, 5 Nov 2024 17:56:58 +0000 Subject: [PATCH 13/18] Style check --- benchmark/scripts/benchmark_group_norm.py | 27 +++- src/liger_kernel/ops/group_norm.py | 140 +++++++++++++------- src/liger_kernel/transformers/functional.py | 4 +- src/liger_kernel/transformers/group_norm.py | 15 ++- test/transformers/test_group_norm.py | 31 +++-- 5 files changed, 151 insertions(+), 66 deletions(-) diff --git a/benchmark/scripts/benchmark_group_norm.py b/benchmark/scripts/benchmark_group_norm.py index 38946eb79..595d379f8 100644 --- a/benchmark/scripts/benchmark_group_norm.py +++ b/benchmark/scripts/benchmark_group_norm.py @@ -24,8 +24,12 @@ def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRun dtype = extra_benchmark_config["dtype"] x_shape = (M, C, H) - triton_ln = LigerGroupNorm(num_channels=C, num_groups=C//channels_per_group, eps=eps).to("cuda") - torch_ln = torch.nn.GroupNorm(num_groups=C//channels_per_group, num_channels=C, eps=eps).to("cuda") + triton_ln = LigerGroupNorm( + num_channels=C, num_groups=C // channels_per_group, eps=eps + ).to("cuda") + torch_ln = torch.nn.GroupNorm( + num_groups=C // channels_per_group, num_channels=C, eps=eps + ).to("cuda") x = torch.randn(x_shape, dtype=dtype, device="cuda") dy = torch.randn_like(x) @@ -69,7 +73,6 @@ def full(): def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: C = input.x provider = input.kernel_provider - mode = input.kernel_operation_mode extra_benchmark_config = input.extra_benchmark_config M = extra_benchmark_config["M"] H = extra_benchmark_config["H"] @@ -78,8 +81,12 @@ def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu dtype = extra_benchmark_config["dtype"] x_shape = (M, C, H) - triton_ln = LigerGroupNorm(num_channels=C, num_groups=C//channels_per_group, eps=eps).to("cuda") - torch_ln = torch.nn.GroupNorm(num_groups=C//channels_per_group, num_channels=C, eps=eps).to("cuda") + triton_ln = LigerGroupNorm( + num_channels=C, num_groups=C // channels_per_group, eps=eps + ).to("cuda") + torch_ln = torch.nn.GroupNorm( + num_groups=C // channels_per_group, num_channels=C, eps=eps + ).to("cuda") x = torch.randn(x_shape, dtype=dtype, device="cuda") dy = torch.randn_like(x) @@ -112,7 +119,15 @@ def full(): "x_label": "num_channels", "x_values": [2**i for i in range(5, 12)], "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [{"M": 128, "H": 512, "channels_per_group": 4, "dtype": torch.float32, "eps": 1e-6}], + "extra_benchmark_configs": [ + { + "M": 128, + "H": 512, + "channels_per_group": 4, + "dtype": torch.float32, + "eps": 1e-6, + } + ], "overwrite": args.overwrite, } diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 4f0faac68..69b4a0480 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -1,14 +1,10 @@ -import math import operator import torch import triton import triton.language as tl -from liger_kernel.ops.utils import ( - compare_version, - ensure_contiguous, -) +from liger_kernel.ops.utils import compare_version, ensure_contiguous if compare_version("triton", operator.ge, "3.0.0"): try: @@ -22,11 +18,12 @@ MAX_FUSED_SIZE = 65536 + @triton.jit def _group_norm_forward_kernel( Y_ptr, # pointer to output, shape (n_rows, n_groups, hidden_size) Y_row_stride, # stride of each row in output - Y_col_stride, # stride of each column in output + Y_col_stride, # stride of each column in output X_ptr, # pointer to input, shape (n_rows, n_groups, hidden_size) X_row_stride, # stride of each row in input X_col_stride, # stride of each column in input @@ -52,7 +49,7 @@ def _group_norm_forward_kernel( X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride - + # Compute mean s = 0.0 for i in range(0, hidden_size, BLOCK_SIZE): @@ -60,9 +57,9 @@ def _group_norm_forward_kernel( mask = hidden_size_offsets < hidden_size X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) s += tl.sum(X) - - m = s/hidden_size - + + m = s / hidden_size + # Compute variance variance = 0.0 for i in range(0, hidden_size, BLOCK_SIZE): @@ -72,14 +69,16 @@ def _group_norm_forward_kernel( X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m) diff = X - m variance += tl.sum(diff * diff) - + variance = variance / hidden_size # 1/std rstd = rsqrt(variance + eps) - + # Normalize - hidden_size_per_channel = hidden_size//channels_per_group - for channel_idx in range(group_idx*channels_per_group, (group_idx+1)*channels_per_group): + hidden_size_per_channel = hidden_size // channels_per_group + for channel_idx in range( + group_idx * channels_per_group, (group_idx + 1) * channels_per_group + ): W = tl.load(W_ptr + channel_idx) B = tl.load(B_ptr + channel_idx) for i in range(0, hidden_size_per_channel, BLOCK_SIZE): @@ -88,14 +87,14 @@ def _group_norm_forward_kernel( X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m) Y = (X - m) * rstd * W + B tl.store(Y_ptr + hidden_size_offsets, Y, mask=mask) - + X_ptr += hidden_size_per_channel Y_ptr += hidden_size_per_channel - + tl.store(Mean_ptr + batch_idx * Mean_row_stride + group_idx * Mean_col_stride, m) tl.store(RSTD_ptr + batch_idx * RSTD_row_stride + group_idx * RSTD_col_stride, rstd) - + @triton.jit def _group_norm_backward_kernel( X_ptr, # pointer to input, shape (n_rows, n_channels, hidden_size) @@ -110,8 +109,8 @@ def _group_norm_backward_kernel( DW_ptr, # pointer to weights grad, shape (n_channels) DB_ptr, # pointer to bias grad, shape (n_channels) UPSTREAM_ptr, # pointer to output grad, shape (n_rows, n_channels, hidden_size) - hidden_size: tl.constexpr, # hidden size - channels_per_group: tl.constexpr, # number of groups in group norm + hidden_size: tl.constexpr, # hidden size + channels_per_group: tl.constexpr, # number of groups in group norm BLOCK_SIZE: tl.constexpr, dtype: tl.constexpr, ): @@ -132,14 +131,20 @@ def _group_norm_backward_kernel( UPSTREAM_ptr += batch_idx * X_row_stride # Mean and rstd are the same shape so have the same strides - mean = tl.load(Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) - rstd = tl.load(RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride) + mean = tl.load( + Mean_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride + ) + rstd = tl.load( + RSTD_ptr + batch_idx * Mean_ptr_row_stride + group_idx * Mean_ptr_col_stride + ) c1 = 0.0 c2 = 0.0 # We need to compute the sum terms of the backprop equations across all channels in the group - for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): + for channel_idx in range( + group_idx * channels_per_group, (group_idx + 1) * channels_per_group + ): # Move the pointers to the correct channel dW = 0.0 @@ -148,9 +153,17 @@ def _group_norm_backward_kernel( for i in range(0, hidden_size, BLOCK_SIZE): hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) mask = hidden_size_offsets < hidden_size - X = tl.load(X_ptr + channel_idx * X_col_stride + hidden_size_offsets, mask=mask, other=0.0) - UPSTREAM_grad = tl.load(UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, mask=mask, other=0.0) - + X = tl.load( + X_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + UPSTREAM_grad = tl.load( + UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + x_hat = (X - mean) * rstd dW += tl.sum(UPSTREAM_grad * x_hat) dB += tl.sum(UPSTREAM_grad) @@ -158,42 +171,56 @@ def _group_norm_backward_kernel( wdy = W * UPSTREAM_grad c1 += tl.sum(x_hat * wdy) c2 += tl.sum(wdy) - + # Need to ensure additions to the same channel are atomic tl.atomic_add(DW_ptr + channel_idx, dW.to(dtype)) tl.atomic_add(DB_ptr + channel_idx, dB.to(dtype)) - + N = hidden_size * channels_per_group - c1 = c1/N - c2 = c2/N - - for channel_idx in range(group_idx * channels_per_group, (group_idx + 1) * channels_per_group): + c1 = c1 / N + c2 = c2 / N + + for channel_idx in range( + group_idx * channels_per_group, (group_idx + 1) * channels_per_group + ): # Move the pointers to the correct channel W = tl.load(W_ptr + channel_idx) for i in range(0, hidden_size, BLOCK_SIZE): hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) mask = hidden_size_offsets < hidden_size - X = tl.load(X_ptr + channel_idx * X_col_stride + hidden_size_offsets, mask=mask, other=0.0) - UPSTREAM_grad = tl.load(UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, mask=mask, other=0.0) - + X = tl.load( + X_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + UPSTREAM_grad = tl.load( + UPSTREAM_ptr + channel_idx * X_col_stride + hidden_size_offsets, + mask=mask, + other=0.0, + ) + x_hat = (X - mean) * rstd wdy = W * UPSTREAM_grad dx = (wdy - (x_hat * c1 + c2)) * rstd - tl.store(DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask) + tl.store( + DX_ptr + channel_idx * X_col_stride + hidden_size_offsets, dx, mask=mask + ) def group_norm_forward(X, num_channels, num_groups, W, B, eps): shape = X.shape batch_size = shape[0] - channels_per_group = num_channels//num_groups + channels_per_group = num_channels // num_groups # Reshape X so that the mean and std are computed across the groups X = X.view(batch_size, num_groups, -1).contiguous() hidden_size = X.shape[-1] BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(hidden_size)) - Y = torch.empty((batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device) + Y = torch.empty( + (batch_size, num_groups, hidden_size), dtype=X.dtype, device=X.device + ) Mean = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) RSTD = torch.zeros((batch_size, num_groups), dtype=X.dtype, device=X.device) - + _group_norm_forward_kernel[(batch_size, num_groups)]( Y, Y.stride(0), @@ -224,7 +251,11 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): hidden_size = dY.shape[-1] channels_per_group = num_channels // num_groups dY = dY.view(batch_size, num_groups, -1) - DX = torch.empty((batch_size, num_groups, hidden_size * channels_per_group), dtype=X.dtype, device=X.device) + DX = torch.empty( + (batch_size, num_groups, hidden_size * channels_per_group), + dtype=X.dtype, + device=X.device, + ) DW = torch.zeros((num_channels), dtype=W.dtype, device=W.device) DB = torch.zeros((num_channels), dtype=B.dtype, device=B.device) triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16 @@ -246,25 +277,44 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): hidden_size, channels_per_group, BLOCK_SIZE=BLOCK_SIZE, - dtype=triton_dtype + dtype=triton_dtype, ) - + return DX.view(*shape), DW, DB class LigerGroupNormFunction(torch.autograd.Function): @staticmethod @ensure_contiguous - def forward(ctx, X, affine_scaling_weight, affine_shifting_bias, num_channels, num_groups, eps): - Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward(X, num_channels, num_groups, affine_scaling_weight, affine_shifting_bias, eps) + def forward( + ctx, + X, + affine_scaling_weight, + affine_shifting_bias, + num_channels, + num_groups, + eps, + ): + Y, X, Mean, RSTD, BLOCK_SIZE = group_norm_forward( + X, + num_channels, + num_groups, + affine_scaling_weight, + affine_shifting_bias, + eps, + ) ctx.num_channels = num_channels ctx.num_groups = num_groups - ctx.save_for_backward(X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD) + ctx.save_for_backward( + X, affine_scaling_weight, affine_shifting_bias, Mean, RSTD + ) return Y @staticmethod @ensure_contiguous def backward(ctx, dY): X, W, B, Mean, RSTD = ctx.saved_tensors - DX, DW, DB = group_norm_backward(dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups) + DX, DW, DB = group_norm_backward( + dY, X, W, B, Mean, RSTD, ctx.num_channels, ctx.num_groups + ) return DX, DW, DB, None, None, None diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py index 39672000f..292c0dba7 100644 --- a/src/liger_kernel/transformers/functional.py +++ b/src/liger_kernel/transformers/functional.py @@ -4,13 +4,13 @@ ) from liger_kernel.ops.fused_linear_jsd import LigerFusedLinearJSDFunction from liger_kernel.ops.geglu import LigerGELUMulFunction +from liger_kernel.ops.group_norm import LigerGroupNormFunction from liger_kernel.ops.jsd import LigerJSDFunction from liger_kernel.ops.kl_div import LigerKLDivLossFunction from liger_kernel.ops.layer_norm import LigerLayerNormFunction from liger_kernel.ops.rms_norm import LigerRMSNormFunction from liger_kernel.ops.rope import LigerRopeFunction from liger_kernel.ops.swiglu import LigerSiLUMulFunction -from liger_kernel.ops.group_norm import LigerGroupNormFunction liger_swiglu = LigerSiLUMulFunction.apply liger_cross_entropy = LigerCrossEntropyFunction.apply @@ -22,4 +22,4 @@ liger_kl_div = LigerKLDivLossFunction.apply liger_jsd = LigerJSDFunction.apply liger_fused_linear_jsd = LigerFusedLinearJSDFunction.apply -liger_group_norm = LigerGroupNormFunction.apply \ No newline at end of file +liger_group_norm = LigerGroupNormFunction.apply diff --git a/src/liger_kernel/transformers/group_norm.py b/src/liger_kernel/transformers/group_norm.py index 2b2631507..d0cc6799b 100644 --- a/src/liger_kernel/transformers/group_norm.py +++ b/src/liger_kernel/transformers/group_norm.py @@ -37,10 +37,19 @@ def __init__(self, num_channels, num_groups, eps=1e-6, bias=False, init_fn="ones def forward(self, hidden_states): # hidden_states: (batch_size, num_channels, *) - assert hidden_states.dim() >= 3, f"Input must have atleast 3 dimensions, got {hidden_states.dim()}" - assert hidden_states.size(1) == self.num_channels, f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}" + assert ( + hidden_states.dim() >= 3 + ), f"Input must have atleast 3 dimensions, got {hidden_states.dim()}" + assert ( + hidden_states.size(1) == self.num_channels + ), f"Input tensor must have {self.num_channels} channels, got {hidden_states.size(1)}" return LigerGroupNormFunction.apply( - hidden_states, self.weight, self.bias, self.num_channels, self.num_groups, self.variance_epsilon + hidden_states, + self.weight, + self.bias, + self.num_channels, + self.num_groups, + self.variance_epsilon, ) def extra_repr(self): diff --git a/test/transformers/test_group_norm.py b/test/transformers/test_group_norm.py index 008d4d11d..32419ed6a 100644 --- a/test/transformers/test_group_norm.py +++ b/test/transformers/test_group_norm.py @@ -1,19 +1,20 @@ -import pytest import random + +import pytest import torch from liger_kernel.transformers.group_norm import LigerGroupNorm - random_batch_size = random.randint(1, 16) random_num_groups = random.randint(1, 32) random_num_channels = random_num_groups * random.randint(1, 16) random_hidden_size = random.randint(1, 8192) + @pytest.mark.parametrize( "batch_size, num_channels, num_groups, hidden_size", - [ - (1, 2, 1, 3), + [ + (1, 1, 1, 3), (1, 4, 2, 4), (16, 12, 3, 4096), (random_batch_size, random_num_channels, random_num_groups, random_hidden_size), @@ -25,7 +26,9 @@ (torch.float32, 1e-4, 1e-4), ], ) -def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol): +def test_liger_group_norm( + batch_size, num_channels, num_groups, hidden_size, dtype, atol, rtol +): torch.manual_seed(0) _tensor = torch.randn( @@ -36,13 +39,19 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty torch_x = _tensor.clone().detach().requires_grad_(True) liger_ln = LigerGroupNorm(num_channels, num_groups, eps=1e-6).to(dtype).cuda() - torch_ln = torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6).to(dtype).cuda() - + torch_ln = ( + torch.nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, eps=1e-6) + .to(dtype) + .cuda() + ) + with torch.no_grad(): torch_ln.weight.copy_(liger_ln.weight) torch_ln.bias.copy_(liger_ln.bias) - liger_output = liger_ln(liger_x,) + liger_output = liger_ln( + liger_x, + ) torch_output = torch_ln(torch_x) assert torch.allclose(liger_output, torch_output, atol=atol, rtol=rtol) @@ -50,7 +59,9 @@ def test_liger_group_norm(batch_size, num_channels, num_groups, hidden_size, dty liger_output.backward(grad_output, retain_graph=True) torch_output.backward(grad_output, retain_graph=True) assert torch.allclose(liger_x.grad, torch_x.grad, atol=atol, rtol=rtol) - assert torch.allclose(liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol), "Bias grads different" - assert torch.allclose( + assert torch.allclose( + liger_ln.bias.grad, torch_ln.bias.grad, atol=atol, rtol=rtol + ), "Bias grads different" + assert torch.allclose( liger_ln.weight.grad, torch_ln.weight.grad, atol=atol, rtol=rtol ), "Weight grads different" From 0e3fb0372e7c693b5069fae397c0f31a1fda1848 Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Tue, 5 Nov 2024 19:11:31 +0000 Subject: [PATCH 14/18] Compute mean and variance using the online algorithm --- src/liger_kernel/ops/group_norm.py | 43 ++++++++++++++---------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 69b4a0480..feba55e76 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -50,39 +50,36 @@ def _group_norm_forward_kernel( X_ptr += batch_idx * X_row_stride + group_idx * X_col_stride Y_ptr += batch_idx * Y_row_stride + group_idx * Y_col_stride - # Compute mean + block_range = tl.arange(0, BLOCK_SIZE) + + # Compute mean and variance using the online algorithm s = 0.0 - for i in range(0, hidden_size, BLOCK_SIZE): - hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + squared_sum = 0.0 + for i in tl.range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + block_range mask = hidden_size_offsets < hidden_size X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=0.0) s += tl.sum(X) + # X**2 + squared_sum += tl.sum(X * X) m = s / hidden_size + + # variance = E[X**2] - E[X]**2 + variance = squared_sum / hidden_size - m * m - # Compute variance - variance = 0.0 - for i in range(0, hidden_size, BLOCK_SIZE): - hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) - mask = hidden_size_offsets < hidden_size - # We need to mask out of index with mean to ensure that the variance remains unaffected - X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m) - diff = X - m - variance += tl.sum(diff * diff) - - variance = variance / hidden_size # 1/std rstd = rsqrt(variance + eps) # Normalize hidden_size_per_channel = hidden_size // channels_per_group - for channel_idx in range( + for channel_idx in tl.range( group_idx * channels_per_group, (group_idx + 1) * channels_per_group ): W = tl.load(W_ptr + channel_idx) B = tl.load(B_ptr + channel_idx) for i in range(0, hidden_size_per_channel, BLOCK_SIZE): - hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + hidden_size_offsets = i + block_range mask = hidden_size_offsets < hidden_size_per_channel X = tl.load(X_ptr + hidden_size_offsets, mask=mask, other=m) Y = (X - m) * rstd * W + B @@ -140,18 +137,18 @@ def _group_norm_backward_kernel( c1 = 0.0 c2 = 0.0 + block_range = tl.arange(0, BLOCK_SIZE) # We need to compute the sum terms of the backprop equations across all channels in the group for channel_idx in range( group_idx * channels_per_group, (group_idx + 1) * channels_per_group ): - # Move the pointers to the correct channel - dW = 0.0 dB = 0.0 - W = tl.load(W_ptr + channel_idx) - for i in range(0, hidden_size, BLOCK_SIZE): - hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + # Move the pointers to the correct channel + W = tl.load(W_ptr + channel_idx) + for i in tl.range(0, hidden_size, BLOCK_SIZE): + hidden_size_offsets = i + block_range mask = hidden_size_offsets < hidden_size X = tl.load( X_ptr + channel_idx * X_col_stride + hidden_size_offsets, @@ -180,13 +177,13 @@ def _group_norm_backward_kernel( c1 = c1 / N c2 = c2 / N - for channel_idx in range( + for channel_idx in tl.range( group_idx * channels_per_group, (group_idx + 1) * channels_per_group ): # Move the pointers to the correct channel W = tl.load(W_ptr + channel_idx) for i in range(0, hidden_size, BLOCK_SIZE): - hidden_size_offsets = i + tl.arange(0, BLOCK_SIZE) + hidden_size_offsets = i + block_range mask = hidden_size_offsets < hidden_size X = tl.load( X_ptr + channel_idx * X_col_stride + hidden_size_offsets, From 8379d6bdedfbb206f1561b0e02b8e61e28ef4416 Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Tue, 5 Nov 2024 19:38:57 +0000 Subject: [PATCH 15/18] New benchmark data --- benchmark/data/all_benchmark_data.csv | 198 +++++++++++++++----------- src/liger_kernel/ops/group_norm.py | 2 +- 2 files changed, 115 insertions(+), 85 deletions(-) diff --git a/benchmark/data/all_benchmark_data.csv b/benchmark/data/all_benchmark_data.csv index d1b092a2b..dfd31091c 100644 --- a/benchmark/data/all_benchmark_data.csv +++ b/benchmark/data/all_benchmark_data.csv @@ -505,87 +505,117 @@ fused_linear_jsd,torch,full,memory,MB,BT,B x T,1024,10609.005859375,10609.005859 fused_linear_jsd,torch,full,memory,MB,BT,B x T,2048,17146.009765625,17146.009765625,17146.009765625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 fused_linear_jsd,torch,full,memory,MB,BT,B x T,4096,30220.017578125,30220.017578125,30220.017578125,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 fused_linear_jsd,torch,full,memory,MB,BT,B x T,8192,56368.015625,56368.015625,56368.015625,"{""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA H100 80GB HBM3,2024-10-09 12:29:35,0.3.1 -group_norm,liger,forward,speed,ms,C,num_channels,32,0.035840000957250595,0.03174399957060814,0.04505600035190582,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 -group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05222399905323982,0.053247999399900436,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 -group_norm,liger,forward,speed,ms,C,num_channels,128,0.0870399996638298,0.0870399996638298,0.08806400001049042,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 -group_norm,liger,forward,speed,ms,C,num_channels,256,0.1443839967250824,0.1443839967250824,0.1454080045223236,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 -group_norm,liger,forward,speed,ms,C,num_channels,512,0.26521599292755127,0.26419198513031006,0.2662400007247925,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 -group_norm,liger,forward,speed,ms,C,num_channels,1024,0.5140479803085327,0.5120000243186951,0.5160959959030151,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 -group_norm,liger,forward,speed,ms,C,num_channels,2048,1.006592035293579,1.0035200119018555,1.0096640586853027,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:21,0.3.1 -group_norm,huggingface,forward,speed,ms,C,num_channels,32,0.04198399931192398,0.04095999896526337,0.04198399931192398,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 -group_norm,huggingface,forward,speed,ms,C,num_channels,64,0.06963200122117996,0.06860800087451935,0.06963200122117996,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 -group_norm,huggingface,forward,speed,ms,C,num_channels,128,0.12492799758911133,0.12492799758911133,0.12492799758911133,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 -group_norm,huggingface,forward,speed,ms,C,num_channels,256,0.2314240038394928,0.2303999960422516,0.2314240038394928,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 -group_norm,huggingface,forward,speed,ms,C,num_channels,512,0.4505600035190582,0.4505600035190582,0.45260798931121826,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 -group_norm,huggingface,forward,speed,ms,C,num_channels,1024,0.9011200070381165,0.8980479836463928,0.9031680226325989,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 -group_norm,huggingface,forward,speed,ms,C,num_channels,2048,1.7950719594955444,1.7920000553131104,1.7960959672927856,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:25,0.3.1 -group_norm,liger,full,speed,ms,C,num_channels,32,0.28569599986076355,0.2815999984741211,0.29388800263404846,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 -group_norm,liger,full,speed,ms,C,num_channels,64,0.19763199985027313,0.19046400487422943,0.3768320083618164,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 -group_norm,liger,full,speed,ms,C,num_channels,128,0.2099200040102005,0.20787200331687927,0.21094399690628052,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 -group_norm,liger,full,speed,ms,C,num_channels,256,0.38092800974845886,0.37990400195121765,0.3829759955406189,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 -group_norm,liger,full,speed,ms,C,num_channels,512,0.7219200134277344,0.719871997833252,0.7229440212249756,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 -group_norm,liger,full,speed,ms,C,num_channels,1024,1.4049279689788818,1.4018559455871582,1.4090240001678467,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 -group_norm,liger,full,speed,ms,C,num_channels,2048,2.7458558082580566,2.743295907974243,2.748415946960449,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:28,0.3.1 -group_norm,huggingface,full,speed,ms,C,num_channels,32,0.12185599654912949,0.11878400295972824,0.13844487071037292,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 -group_norm,huggingface,full,speed,ms,C,num_channels,64,0.2099200040102005,0.2088959962129593,0.21094399690628052,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 -group_norm,huggingface,full,speed,ms,C,num_channels,128,0.33792001008987427,0.33689600229263306,0.33792001008987427,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 -group_norm,huggingface,full,speed,ms,C,num_channels,256,0.5908480286598206,0.5908480286598206,0.591871976852417,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 -group_norm,huggingface,full,speed,ms,C,num_channels,512,1.1110399961471558,1.106943964958191,1.1141120195388794,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 -group_norm,huggingface,full,speed,ms,C,num_channels,1024,2.160640001296997,2.1585919857025146,2.1780478954315186,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 -group_norm,huggingface,full,speed,ms,C,num_channels,2048,4.2690558433532715,4.2485761642456055,4.274585723876953,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:31,0.3.1 -group_norm,liger,backward,speed,ms,C,num_channels,32,0.07884799689054489,0.07475200295448303,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 -group_norm,liger,backward,speed,ms,C,num_channels,64,0.08499199897050858,0.08294399827718735,0.2908160090446472,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 -group_norm,liger,backward,speed,ms,C,num_channels,128,0.13209599256515503,0.131071999669075,0.13312000036239624,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 -group_norm,liger,backward,speed,ms,C,num_channels,256,0.24166400730609894,0.24063999950885773,0.24268800020217896,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 -group_norm,liger,backward,speed,ms,C,num_channels,512,0.4556800127029419,0.4546560049057007,0.4567039906978607,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 -group_norm,liger,backward,speed,ms,C,num_channels,1024,0.8919039964675903,0.8908799886703491,0.8939520120620728,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 -group_norm,liger,backward,speed,ms,C,num_channels,2048,1.7643519639968872,1.7623039484024048,1.7663999795913696,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:34,0.3.1 -group_norm,huggingface,backward,speed,ms,C,num_channels,32,0.08499199897050858,0.08396799862384796,0.08601599931716919,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,backward,speed,ms,C,num_channels,64,0.14643199741840363,0.14643199741840363,0.14745600521564484,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,backward,speed,ms,C,num_channels,128,0.2170879989862442,0.21606400609016418,0.2181120067834854,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,backward,speed,ms,C,num_channels,256,0.3614720106124878,0.3604480028152466,0.3624959886074066,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,backward,speed,ms,C,num_channels,512,0.652288019657135,0.6512640118598938,0.6563839912414551,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,backward,speed,ms,C,num_channels,1024,1.2584960460662842,1.2533760070800781,1.2615679502487183,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,backward,speed,ms,C,num_channels,2048,2.4688639640808105,2.465996742248535,2.4829952716827393,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,full,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,full,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,full,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,full,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,full,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,full,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,full,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,full,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,full,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,full,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,full,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,full,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,full,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,full,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,forward,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,forward,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,forward,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,forward,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,forward,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,forward,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,forward,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,forward,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,forward,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,forward,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,forward,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,forward,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,forward,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,huggingface,forward,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:37,0.3.1 -group_norm,liger,backward,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,liger,backward,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,liger,backward,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,liger,backward,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,liger,backward,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,liger,backward,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,liger,backward,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,huggingface,backward,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,huggingface,backward,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,huggingface,backward,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,huggingface,backward,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,huggingface,backward,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,huggingface,backward,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 -group_norm,huggingface,backward,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-PCIE-40GB,2024-11-05 17:37:38,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,32,0.03481600061058998,0.03379200026392937,0.03993599861860275,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,64,0.05222399905323982,0.05119999870657921,0.05222399905323982,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,128,0.08499199897050858,0.08396799862384796,0.08499199897050858,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,256,0.1454080045223236,0.1443839967250824,0.14643199741840363,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,512,0.2611199915409088,0.2611199915409088,0.26214399933815,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,1024,0.49459201097488403,0.4925439953804016,0.4976640045642853,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1 +group_norm,liger,forward,speed,ms,C,num_channels,2048,0.9789440035820007,0.9758719801902771,0.9820160269737244,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:35,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,32,0.04198399931192398,0.04198399931192398,0.043007999658584595,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:39,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,64,0.06963200122117996,0.06963200122117996,0.07065600156784058,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:39,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,128,0.12697599828243256,0.12595200538635254,0.12697599828243256,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:39,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,256,0.2314240038394928,0.2303999960422516,0.2314240038394928,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:39,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,512,0.4423680007457733,0.4423680007457733,0.4423680007457733,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:39,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,1024,0.8642560243606567,0.8632320165634155,0.8642560243606567,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:39,0.3.1 +group_norm,huggingface,forward,speed,ms,C,num_channels,2048,1.70905601978302,1.7080320119857788,1.7100800275802612,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:39,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,32,0.6625279784202576,0.49930238723754883,0.6850559711456299,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:43,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,64,0.6666240096092224,0.6604800224304199,0.6768640279769897,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:43,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,128,0.6615039706230164,0.6574079990386963,0.6696959733963013,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:43,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,256,0.6912000179290771,0.6850559711456299,0.6952959895133972,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:43,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,512,0.7188479900360107,0.7167999744415283,0.719871997833252,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:43,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,1024,1.4008320569992065,1.3987840414047241,1.4039039611816406,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:43,0.3.1 +group_norm,liger,full,speed,ms,C,num_channels,2048,2.7494399547576904,2.746367931365967,2.7535359859466553,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:43,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,32,0.3235839903354645,0.26521599292755127,0.32767999172210693,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:46,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,64,0.3246079981327057,0.32153600454330444,0.32972800731658936,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:46,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,128,0.33792001008987427,0.33689600229263306,0.3389439880847931,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:46,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,256,0.5877760052680969,0.5877760052680969,0.5888000130653381,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:46,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,512,1.0782719850540161,1.077247977256775,1.0792959928512573,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:46,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,1024,2.0797441005706787,2.0787200927734375,2.081792116165161,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:46,0.3.1 +group_norm,huggingface,full,speed,ms,C,num_channels,2048,4.068352222442627,4.067327976226807,4.069375991821289,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:46,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,32,0.29388800263404846,0.289792001247406,0.2979840040206909,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:50,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,64,0.29900801181793213,0.2949120104312897,0.30720001459121704,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:50,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,128,0.29286399483680725,0.289792001247406,0.2979840040206909,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:50,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,256,0.3184640109539032,0.31436800956726074,0.3235839903354645,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:50,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,512,0.45875200629234314,0.45772799849510193,0.45977601408958435,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:50,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,1024,0.8939520120620728,0.8919039964675903,0.894976019859314,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:50,0.3.1 +group_norm,liger,backward,speed,ms,C,num_channels,2048,1.7720320224761963,1.7702912092208862,1.773568034172058,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:50,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,32,0.1515520066022873,0.13516800105571747,0.15667200088500977,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,64,0.15360000729560852,0.15052799880504608,0.15667200088500977,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,128,0.2170879989862442,0.2170879989862442,0.2181120067834854,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,256,0.3614720106124878,0.3614720106124878,0.3624959886074066,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,512,0.6410239934921265,0.6399999856948853,0.6420480012893677,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,1024,1.222656011581421,1.2216320037841797,1.223680019378662,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,speed,ms,C,num_channels,2048,2.3654398918151855,2.3633921146392822,2.3664638996124268,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,full,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,full,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,forward,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,forward,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,32,40.01171875,40.01171875,40.01171875,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,64,80.01953125,80.01953125,80.01953125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,128,160.03515625,160.03515625,160.03515625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,256,320.0703125,320.0703125,320.0703125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,512,640.140625,640.140625,640.140625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,1024,1280.28125,1280.28125,1280.28125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,liger,backward,memory,MB,C,num_channels,2048,2560.5625,2560.5625,2560.5625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,32,40.06640625,40.06640625,40.06640625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,64,80.12890625,80.12890625,80.12890625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,128,160.25390625,160.25390625,160.25390625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,256,320.5078125,320.5078125,320.5078125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,512,641.015625,641.015625,641.015625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,1024,1282.03125,1282.03125,1282.03125,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +group_norm,huggingface,backward,memory,MB,C,num_channels,2048,2564.0625,2564.0625,2564.0625,"{""M"": 128, ""H"": 512, ""channels_per_group"": 4, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:20:53,0.3.1 +layer_norm,liger,forward,speed,ms,N,hidden size,1024,0.035840000957250595,0.03481600061058998,0.035840000957250595,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:51,0.3.1 +layer_norm,liger,forward,speed,ms,N,hidden size,2048,0.05939200147986412,0.058368001133203506,0.060416001826524734,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:51,0.3.1 +layer_norm,liger,forward,speed,ms,N,hidden size,4096,0.10751999914646149,0.10751999914646149,0.1085439994931221,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:51,0.3.1 +layer_norm,liger,forward,speed,ms,N,hidden size,8192,0.20582400262355804,0.20479999482631683,0.20684799551963806,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:51,0.3.1 +layer_norm,liger,forward,speed,ms,N,hidden size,16384,0.3993600010871887,0.3983359932899475,0.40140798687934875,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:51,0.3.1 +layer_norm,huggingface,forward,speed,ms,N,hidden size,1024,0.03788800165057182,0.03788800165057182,0.03891199827194214,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:53,0.3.1 +layer_norm,huggingface,forward,speed,ms,N,hidden size,2048,0.0655359998345375,0.0655359998345375,0.06656000018119812,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:53,0.3.1 +layer_norm,huggingface,forward,speed,ms,N,hidden size,4096,0.14745600521564484,0.14643199741840363,0.14847999811172485,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:53,0.3.1 +layer_norm,huggingface,forward,speed,ms,N,hidden size,8192,0.31334400177001953,0.3123199939727783,0.31436800956726074,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:53,0.3.1 +layer_norm,huggingface,forward,speed,ms,N,hidden size,16384,0.6133760213851929,0.6123520135879517,0.6154239773750305,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:27:53,0.3.1 +layer_norm,liger,full,speed,ms,N,hidden size,1024,0.6860799789428711,0.6146048903465271,0.7049216032028198,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:02,0.3.1 +layer_norm,liger,full,speed,ms,N,hidden size,2048,0.6789119839668274,0.6737920045852661,0.6912000179290771,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:02,0.3.1 +layer_norm,liger,full,speed,ms,N,hidden size,4096,0.6686720252037048,0.6635519862174988,0.681984007358551,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:02,0.3.1 +layer_norm,liger,full,speed,ms,N,hidden size,8192,0.6789119839668274,0.5908480286598206,0.6932479739189148,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:02,0.3.1 +layer_norm,liger,full,speed,ms,N,hidden size,16384,6.071296215057373,5.331148624420166,6.08235502243042,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:02,0.3.1 +layer_norm,huggingface,full,speed,ms,N,hidden size,1024,0.13312000036239624,0.13209599256515503,0.13312000036239624,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1 +layer_norm,huggingface,full,speed,ms,N,hidden size,2048,0.23244799673557281,0.2303999960422516,0.23347200453281403,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1 +layer_norm,huggingface,full,speed,ms,N,hidden size,4096,0.5242879986763,0.5232639908790588,0.5263360142707825,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1 +layer_norm,huggingface,full,speed,ms,N,hidden size,8192,1.0168319940567017,1.0147839784622192,1.018880009651184,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1 +layer_norm,huggingface,full,speed,ms,N,hidden size,16384,1.994752049446106,1.9916800260543823,1.9967999458312988,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1 +layer_norm,liger,full,memory,MB,N,hidden size,1024,80.90625,80.90625,80.90625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1 +layer_norm,liger,full,memory,MB,N,hidden size,2048,161.78125,161.78125,161.78125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1 +layer_norm,liger,full,memory,MB,N,hidden size,4096,323.53125,323.53125,323.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1 +layer_norm,liger,full,memory,MB,N,hidden size,8192,647.03125,647.03125,647.03125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1 +layer_norm,liger,full,memory,MB,N,hidden size,16384,1294.03125,1294.03125,1294.03125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:04,0.3.1 +layer_norm,huggingface,full,memory,MB,N,hidden size,1024,80.0625,80.0625,80.0625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1 +layer_norm,huggingface,full,memory,MB,N,hidden size,2048,160.09375,160.09375,160.09375,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1 +layer_norm,huggingface,full,memory,MB,N,hidden size,4096,320.15625,320.15625,320.15625,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1 +layer_norm,huggingface,full,memory,MB,N,hidden size,8192,640.28125,640.28125,640.28125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1 +layer_norm,huggingface,full,memory,MB,N,hidden size,16384,1280.53125,1280.53125,1280.53125,"{""M"": 4096, ""dtype"": ""torch.float32"", ""eps"": 1e-06}",NVIDIA A100-SXM4-40GB,2024-11-05 19:28:05,0.3.1 diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index feba55e76..edadf7132 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -66,7 +66,7 @@ def _group_norm_forward_kernel( m = s / hidden_size # variance = E[X**2] - E[X]**2 - variance = squared_sum / hidden_size - m * m + variance = (squared_sum / hidden_size) - (m * m) # 1/std rstd = rsqrt(variance + eps) From 0186bffb07afb69c79ae1a67d537514233d7eaad Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Tue, 5 Nov 2024 19:43:10 +0000 Subject: [PATCH 16/18] checkstyle --- dev/modal/tests.py | 1 + src/liger_kernel/ops/group_norm.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dev/modal/tests.py b/dev/modal/tests.py index 1b52b40db..880a2f299 100644 --- a/dev/modal/tests.py +++ b/dev/modal/tests.py @@ -17,6 +17,7 @@ @app.function(gpu="A10G", mounts=[repo], timeout=60 * 10) def liger_tests(): import subprocess + subprocess.run(["pip", "install", "-e", "."], check=True, cwd="/root/liger-kernel") subprocess.run(["make", "test"], check=True, cwd="/root/liger-kernel") subprocess.run(["make", "test-convergence"], check=True, cwd="/root/liger-kernel") diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index edadf7132..78c5f70bc 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -64,7 +64,7 @@ def _group_norm_forward_kernel( squared_sum += tl.sum(X * X) m = s / hidden_size - + # variance = E[X**2] - E[X]**2 variance = (squared_sum / hidden_size) - (m * m) @@ -146,7 +146,7 @@ def _group_norm_backward_kernel( dW = 0.0 dB = 0.0 # Move the pointers to the correct channel - W = tl.load(W_ptr + channel_idx) + W = tl.load(W_ptr + channel_idx) for i in tl.range(0, hidden_size, BLOCK_SIZE): hidden_size_offsets = i + block_range mask = hidden_size_offsets < hidden_size From 0f88ae34b0dae6b67bbb9f4d4c2a16d26d56254a Mon Sep 17 00:00:00 2001 From: pramodith <16939722+pramodith@users.noreply.github.com> Date: Wed, 6 Nov 2024 10:58:04 +0000 Subject: [PATCH 17/18] Add a few comments. --- src/liger_kernel/ops/group_norm.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index 78c5f70bc..fab92497b 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -35,8 +35,8 @@ def _group_norm_forward_kernel( RSTD_col_stride, # stride of each column in rstd W_ptr, # pointer to W B_ptr, # pointer to B - hidden_size, - channels_per_group, + hidden_size, # hidden size of X + channels_per_group, # the number of channels per group eps, BLOCK_SIZE: tl.constexpr, ): @@ -114,10 +114,14 @@ def _group_norm_backward_kernel( """ References: https://nn.labml.ai/normalization/group_norm/index.html + https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md + The backprop equations are the same for group_norm and layer_norm - the only difference here is that we load the W, Mean, Rstd corresponding to the + the only difference here is that we load the Mean, Rstd corresponding to the group we're computing gradients for and the mean and rstd are computed over n-channels so the total number of elements we compute the mean over is num_channels_per_group * hidden_size + + We also need to load the Weights corresponding to the current channel to compute the gradients. """ batch_idx = tl.program_id(0) group_idx = tl.program_id(1) @@ -238,8 +242,8 @@ def group_norm_forward(X, num_channels, num_groups, W, B, eps): eps, BLOCK_SIZE=BLOCK_SIZE, ) - Y = Y.view(*shape) - return Y, X.view(*shape), Mean, RSTD, BLOCK_SIZE + # Return tensors in the original shape + return Y.view(*shape), X.view(*shape), Mean, RSTD, BLOCK_SIZE def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): @@ -276,7 +280,8 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): BLOCK_SIZE=BLOCK_SIZE, dtype=triton_dtype, ) - + + # Return tensors in the original shape return DX.view(*shape), DW, DB From 138cbf81b55ea83f1f340ad3aa8044bc1473e482 Mon Sep 17 00:00:00 2001 From: Shao Tang Date: Thu, 7 Nov 2024 11:52:35 -0800 Subject: [PATCH 18/18] Update group_norm.py --- src/liger_kernel/ops/group_norm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/ops/group_norm.py b/src/liger_kernel/ops/group_norm.py index fab92497b..aeb4323f3 100644 --- a/src/liger_kernel/ops/group_norm.py +++ b/src/liger_kernel/ops/group_norm.py @@ -35,8 +35,8 @@ def _group_norm_forward_kernel( RSTD_col_stride, # stride of each column in rstd W_ptr, # pointer to W B_ptr, # pointer to B - hidden_size, # hidden size of X - channels_per_group, # the number of channels per group + hidden_size, # hidden size of X + channels_per_group, # the number of channels per group eps, BLOCK_SIZE: tl.constexpr, ): @@ -280,7 +280,7 @@ def group_norm_backward(dY, X, W, B, Mean, RSTD, num_channels, num_groups): BLOCK_SIZE=BLOCK_SIZE, dtype=triton_dtype, ) - + # Return tensors in the original shape return DX.view(*shape), DW, DB