Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT Adding experimental feature : Triton mm int8xint2 #195

Merged
merged 18 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions src/liger_kernel/ops/experimental/mm_int8int2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import torch

import triton
import triton.language as tl

def unpack_weights(packed: torch.Tensor, bits: int = 2) -> torch.Tensor:
values_per_item = 8 // bits
packed_shape = packed.shape

if len(packed_shape) == 1:
original_row_dim = packed_shape[0] * values_per_item
unpacked_shape = (original_row_dim,)
else:
original_row_dim = packed_shape[0] * values_per_item
unpacked_shape = (original_row_dim, *packed_shape[1:])

unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)

for i in range(values_per_item):
start = i * packed_shape[0]
end = start + packed_shape[0]
mask = (3 << (2 * i))
unpacked[start:end] = (packed & mask) >> (2 * i)

unpacked = unpacked.to(torch.int32) - 1
return unpacked

def pack_weights(intweights: torch.Tensor, bits: int = 2) -> torch.Tensor:
intweights += 1
original_shape = intweights.shape
values_per_item = 8 // bits
row_dim = (original_shape[0] + values_per_item - 1) // values_per_item

if len(original_shape) == 1:
packed_tensor_shape = (row_dim,)
else:
packed_tensor_shape = (row_dim, *original_shape[1:])

packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
unpacked = intweights.to(torch.uint8)

def lshift(t: torch.Tensor, bits: int):
return t << bits

it = min(values_per_item, (original_shape[0] // row_dim) + 1)
for i in range(it):
start = i * row_dim
end = min(start + row_dim, original_shape[0])
packed[: (end - start)] |= lshift(unpacked[start:end], bits * i)

return packed


def get_cuda_autotune_config():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should not be cuda?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I didn't understand the comment

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i mean triton and cuda are different things. maybe replace cuda with triton

return [
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4}, num_stages=4,
num_warps=4)
]

@triton.autotune(
configs=get_cuda_autotune_config(),
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
Copy link
Collaborator

@ByronHsu ByronHsu Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a paper reference and more comments. it is a bit hard to understand the current code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I will add some comments to explain the process when i have some time

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

claude / chatgpt can be helpful

a_ptr, b_ptr, c_ptr,
M, N, K: tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
# We want K / 4 to be divisible by BLOCK_SIZE_K so that the multiplication can be aligned
tl.static_assert(K % (4*BLOCK_SIZE_K) == 0, "K / 4 must be divisible by BLOCK_SIZE_K => K divisible by 4*BLOCK_SIZE_K")

# determine the block id in the 1D grid
pid = tl.program_id(axis=0)
#
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)

for i in range(4) :
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K) ):
k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j

a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K , other=0)
mask = 3<<(2*i)
b = ((b_uint8 & mask) >> (2*i))

# We accumulate the tiles along the K dimension.
tensor_full = tl.full((1,), 1, dtype=tl.int8)

accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)

a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk

c = accumulator

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)


def matmul(a, b):
assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
_, N = b.shape
# c is in int32 to avoid any overflows or underflows
c = torch.empty((M, N), device=a.device, dtype=torch.int32)
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
)
return c



76 changes: 76 additions & 0 deletions test/transformers/test_mm_int8int2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import pytest
import torch

from liger_kernel.transformers.experimental.mm_int8int2 import matmul, unpack_weights, pack_weights


# input_features = size*4 when the weight matrix is unpacked
@pytest.mark.parametrize(
"size",
[
2048,
1024,
512,
],
)
@pytest.mark.parametrize(
"batch_size",
[
1,
2,
3,
8,
100
],
)
@pytest.mark.parametrize(
"seq_len",
[
1,
7,
16,
2048
],
)
@pytest.mark.parametrize(
"out_features",
[
1024,
2048,
4096,
10000,
],
)
@pytest.mark.parametrize(
"atol, rtol, device",
[
(1e-2, 1e-2, "cuda"),
],
)
def test_kernel_correctness(batch_size, seq_len, out_features, size, atol, rtol, device):
print(f"\nTesting kernel with size: {size}, atol: {atol}, rtol: {rtol}")

# Generate the random tensors
ht = torch.randint(-127, 127, (batch_size, seq_len, size * 4), device=device, dtype=torch.int8)
u = torch.randint(0, 255, (out_features, size), device=device, dtype=torch.uint8)

# Calculate dimensions
B, M, N = ht.size()

# Compute triton output
triton_output = matmul(ht.view(B * M, N), u.T.contiguous()).view(B, M, -1)

# Validate packing and unpacking of weights
assert (pack_weights(unpack_weights(u.T), 2) == u.T).all(), "Packed weights do not match original weights."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imo we can separate the correctness of pack + unpack to another testing func

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay i will


# Unpack weights and compute torch output
unpacked = unpack_weights(u.T, bits=2).T
torch_output = torch.matmul(ht.to(torch.float32), unpacked.T.contiguous().to(torch.float32))

# Print the results (optional, can be commented out)
print("triton_output =", triton_output)
print("torch_output =", torch_output)

# Check if outputs are close within the given tolerances
assert torch.allclose(triton_output, torch_output.to(torch.int32), atol=atol, rtol=rtol), "Results differ"