FEAT Adding experimental feature : Triton mm int8xint2 #195
@@ -0,0 +1,169 @@
import torch

import triton
import triton.language as tl

def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor:
    # Each uint8 in `packed` holds 8 // bits values; unpack them back into full rows.
    values_per_item = 8 // bits
    packed_shape = packed.shape

    if len(packed_shape) == 1:
        original_row_dim = packed_shape[0] * values_per_item
        unpacked_shape = (original_row_dim,)
    else:
        original_row_dim = packed_shape[0] * values_per_item
        unpacked_shape = (original_row_dim, *packed_shape[1:])

    unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)

    # The i-th block of rows is stored in bit positions 2*i and 2*i + 1 of each byte.
    for i in range(values_per_item):
        start = i * packed_shape[0]
        end = start + packed_shape[0]
        mask = 3 << (2 * i)
        unpacked[start:end] = (packed & mask) >> (2 * i)

    # Shift from the unsigned 2-bit range [0, 3] back to the signed range [-1, 2].
    unpacked = unpacked.to(torch.int32) - 1
    return unpacked

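As a quick illustration of the layout unpack_weights expects (the byte value below is made up, not taken from the PR): every packed byte stores four 2-bit values, and the i-th block of rows lives in bit positions 2*i and 2*i + 1.

packed = torch.tensor([[0b11100100]], dtype=torch.uint8)   # bits, high to low: 11 10 01 00
unpacked = unpack_weights(packed, bits=2)                   # shape (4, 1), dtype int32
# unpacked is [[-1], [0], [1], [2]]: block i reads bits 2*i..2*i+1, then subtracts the +1 offset
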
def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
    # Shift from the signed range [-1, 2] to the unsigned 2-bit range [0, 3] (in place).
    intweights += 1
    original_shape = intweights.shape
    values_per_item = 8 // bits
    row_dim = (original_shape[0] + values_per_item - 1) // values_per_item

    if len(original_shape) == 1:
        packed_tensor_shape = (row_dim,)
    else:
        packed_tensor_shape = (row_dim, *original_shape[1:])

    packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
    unpacked = intweights.to(torch.uint8)

    def lshift(t: torch.Tensor, bits: int):
        return t << bits

    # Pack values_per_item blocks of rows into the bit positions of each byte, low to high.
    it = min(values_per_item, (original_shape[0] // row_dim) + 1)
    for i in range(it):
        start = i * row_dim
        end = min(start + row_dim, original_shape[0])
        packed[: (end - start)] |= lshift(unpacked[start:end], bits * i)

    return packed

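A quick round-trip sketch (hypothetical shapes; runs on CPU or GPU). Values must fit the 2-bit range after the +1 offset, i.e. lie in [-1, 2], and for an exact round trip the row dimension should be a multiple of values_per_item = 4:

w = torch.randint(-1, 3, (8, 3), dtype=torch.int32)
w_packed = pack_weights(w.clone(), bits=2)   # clone: pack_weights shifts its input in place
assert w_packed.shape == (2, 3) and w_packed.dtype == torch.uint8
assert torch.equal(unpack_weights(w_packed, bits=2), w)
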
def get_cuda_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
                      num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=4,
                      num_warps=4)
    ]

@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(

Review comment: add a paper reference and more comments, it is a bit hard to understand the current code.
Reply: Yes, I will add some comments to explain the process when I have some time.
Review comment: claude / chatgpt can be helpful.
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Only triggered when TRITON_DEBUG is set to 1, e.g. TRITON_DEBUG=1 python script.py

Review comment: what is this?
Reply: It's a device_assert; it only works when TRITON_DEBUG is not set to 0, and it ensures that K is a multiple of BLOCK_SIZE_K * 4, which is the case for weight matrices, for alignment purposes. In the future we can find a way to make it more general.
Review comment: Does that mean it doesn't error out even if the alignment is incorrect when TRITON_DEBUG is not enabled? Wondering if we can use https://triton-lang.org/main/python-api/generated/triton.language.static_assert.html#triton.language.static_assert
Review comment: or can we just uplift the assertion before the kernel launch (line 158)?
Reply: Now it's working with static_assert, I just had to specify that K is tl.constexpr.
    # We want K / 4 to be divisible by BLOCK_SIZE_K so that the multiplication can be aligned
    tl.device_assert(K % (4 * BLOCK_SIZE_K) == 0, "K / 4 must be divisible by BLOCK_SIZE_K => K divisible by 4*BLOCK_SIZE_K")

    # Map the flat program id to a (pid_m, pid_n) tile, grouping tiles along M for better L2 reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)

    # B is packed: each byte holds 4 two-bit values, so the same packed bytes are read once
    # per outer pass, and pass i extracts the values stored in bit positions 2*i and 2*i + 1.
    for i in range(4):
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):
            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j

            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0)
            mask = 3 << (2 * i)
            b = (b_uint8 & mask) >> (2 * i)

            # We accumulate the tiles along the K dimension.
            tensor_full = tl.full((1,), 1, dtype=tl.int8)

            # Undo the +1 packing offset before the int8 x int8 -> int32 dot product.
            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

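Following up on the static_assert discussion in the review thread above, here is a minimal sketch (not the code in this PR) of how the check could be made compile-time by declaring K as tl.constexpr; the name matmul_kernel_static_check and the trimmed body are illustrative only.

@triton.jit
def matmul_kernel_static_check(
    a_ptr, b_ptr, c_ptr,
    M, N,
    K: tl.constexpr,  # compile-time constant so the assert below is evaluated statically
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Fails at compile time if the packed K dimension is misaligned, regardless of TRITON_DEBUG.
    tl.static_assert(K % (4 * BLOCK_SIZE_K) == 0, "K must be divisible by 4 * BLOCK_SIZE_K")
    # ... body identical to matmul_kernel above ...
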
def matmul(a, b):
    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix needs to be packed"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    _, N = b.shape
    # c is in int32 to avoid any overflows or underflows
    c = torch.empty((M, N), device=a.device, dtype=torch.int32)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

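A hypothetical end-to-end usage sketch (assumes a CUDA device and shapes that satisfy the kernel's alignment check; the tensors below are made up for illustration):

x = torch.randint(-127, 127, (16, 4096), device="cuda", dtype=torch.int8)  # activations, (M, K)
w = torch.randint(-1, 2, (4096, 1024), device="cuda", dtype=torch.int32)   # ternary weights, (K, N)
w_packed = pack_weights(w.clone(), bits=2)                                 # packed weights, (K // 4, N), uint8
y = matmul(x, w_packed)                                                    # int32 output, (M, N)
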
@@ -0,0 +1,76 @@
import pytest
import torch

from liger_kernel.transformers.experimental.mm_int8int2 import matmul, unpack_weights, pack_weights

# input_features = size*4 when the weight matrix is unpacked
@pytest.mark.parametrize(
    "size",
    [
        2048,
        1024,
        512,
    ],
)
@pytest.mark.parametrize(
    "batch_size",
    [
        1,
        2,
        3,
        8,
        100
    ],
)
@pytest.mark.parametrize(
    "seq_len",
    [
        1,
        7,
        16,
        2048
    ],
)
@pytest.mark.parametrize(
    "out_features",
    [
        1024,
        2048,
        4096,
        10000,
    ],
)
@pytest.mark.parametrize(
    "atol, rtol, device",
    [
        (1e-2, 1e-2, "cuda"),
    ],
)
def test_kernel_correctness(batch_size, seq_len, out_features, size, atol, rtol, device):
    print(f"\nTesting kernel with size: {size}, atol: {atol}, rtol: {rtol}")

    # Generate the random tensors
    ht = torch.randint(-127, 127, (batch_size, seq_len, size * 4), device=device, dtype=torch.int8)
    u = torch.randint(0, 255, (out_features, size), device=device, dtype=torch.uint8)

    # Calculate dimensions
    B, M, N = ht.size()

    # Compute triton output
    triton_output = matmul(ht.view(B * M, N), u.T.contiguous()).view(B, M, -1)

    # Validate packing and unpacking of weights
    assert (pack_weights(unpack_weights(u.T), 2) == u.T).all(), "Packed weights do not match original weights."

Review comment: imo we can separate the correctness of pack + unpack into another testing func.
Reply: Okay, I will.

    # Unpack weights and compute torch output
    unpacked = unpack_weights(u.T, bits=2).T
    torch_output = torch.matmul(ht.to(torch.float32), unpacked.T.contiguous().to(torch.float32))

    # Print the results (optional, can be commented out)
    print("triton_output =", triton_output)
    print("torch_output =", torch_output)

    # Check if outputs are close within the given tolerances
    assert torch.allclose(triton_output, torch_output.to(torch.int32), atol=atol, rtol=rtol), "Results differ"

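A minimal sketch of the separate pack/unpack test suggested in the review thread above (the parameter values and test name are illustrative only):

@pytest.mark.parametrize("size, out_features", [(512, 1024), (2048, 4096)])
def test_pack_unpack_roundtrip(size, out_features):
    u = torch.randint(0, 255, (out_features, size), dtype=torch.uint8)
    assert (pack_weights(unpack_weights(u.T, bits=2), bits=2) == u.T).all(), "Packed weights do not match original weights."
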
Review comment: should this not be cuda?
Reply: Sorry, I didn't understand the comment.
Review comment: I mean triton and cuda are different things. Maybe replace cuda with triton.