[AMD][TL] Introduce K Pack and a Conflict Free swizzling into Matrix Core #248

Merged 7 commits on Nov 27, 2024
3rdparty/tvm (2 changes: 1 addition & 1 deletion)
Submodule tvm updated from a12155 to e52254
@@ -6,7 +6,7 @@
 from typing import Optional, List
 from bitblas.tl.utils import (
     get_mma_micro_size,
-    make_swizzle_layout,
+    make_mma_swizzle_layout as make_swizzle_layout,
 )

 from bitblas.tl.mma_macro_generator import (
@@ -7,7 +7,7 @@
 from typing import Optional, List
 from bitblas.tl.utils import (
     get_mma_micro_size,
-    make_swizzle_layout,
+    make_mma_swizzle_layout as make_swizzle_layout,
 )
 from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
     MatmulFineGrainScheduler,
@@ -6,7 +6,7 @@
 from typing import Optional, List, Literal
 from bitblas.tl.utils import (
     get_mma_micro_size,  # noqa: F401
-    make_swizzle_layout,  # noqa: F401
+    make_mma_swizzle_layout as make_swizzle_layout,  # noqa: F401
 )

 from bitblas.tl.mma_macro_generator import (
@@ -6,7 +6,7 @@
 from typing import Optional, List
 from bitblas.tl.utils import (
     get_mma_micro_size,  # noqa: F401
-    make_swizzle_layout,  # noqa: F401
+    make_mma_swizzle_layout as make_swizzle_layout,  # noqa: F401
     index_to_coordinates,  # noqa: F401
 )
@@ -6,7 +6,7 @@
 from typing import Optional
 from bitblas.tl.utils import (
     get_mma_micro_size,  # noqa: F401
-    make_swizzle_layout,  # noqa: F401
+    make_mma_swizzle_layout as make_swizzle_layout,  # noqa: F401
 )
 from .finegrained_primitive_tensorcore import MatmulDequantizeFineGrainedScheduler
 from bitblas.tl.mma_macro_generator import (
@@ -6,7 +6,7 @@
 from typing import Optional, List
 from bitblas.tl.utils import (
     get_mma_micro_size,  # noqa: F401
-    make_swizzle_layout,  # noqa: F401
+    make_mma_swizzle_layout as make_swizzle_layout,  # noqa: F401
     index_to_coordinates,  # noqa: F401
 )
 from bitblas.base.arch import TileDevice
bitblas/tl/__init__.py (1 change: 0 additions & 1 deletion)
@@ -2,7 +2,6 @@
 # Licensed under the MIT License.

 from .utils import (
-    get_swizzle_layout,  # noqa: F401
     mma_store_index_map,  # noqa: F401
     get_ldmatrix_offset,  # noqa: F401
 )
bitblas/tl/mfma_layout.py (52 changes: 52 additions & 0 deletions)
@@ -1,5 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

+from tvm import DataType
 import tvm.tl.language as T
 from tvm.runtime import convert
+
+
@@ -71,3 +74,52 @@ def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id):
     i = thread_id % 16
     j = local_id + (thread_id // 16) * 4
     return i, j
+
+
+def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
+    i = thread_id % 16
+    j = (thread_id // 16) * 8 + local_id
+    return i, j
+
+
+def shared_16x32_to_local_64x8_layout_A(i, j):
+    thread_id = i + 16 * (j // 8)
+    local = (j % 8)
+    return thread_id, local
+
+
+def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id):
+    i = local_id + (thread_id // 16) * 8
+    j = thread_id % 16
+    return i, j
+
+
+def shared_16x32_to_local_64x8_layout_B(i, j):
+    thread_id = j + (i // 8) * 16
+    local = (i % 8)
+    return thread_id, local
+
+
+def make_mfma_swizzle_layout(shared_buf, vecSize=8):
+    dtype = shared_buf.dtype
+    shape = shared_buf.shape
+
+    numBanks = 32
+    bankBitWidth = 32
+    SIMDWidth = 16
+
+    innerDimLength = shape[-1]
+    typeWidthInBit = DataType(dtype).bits
+
+    elemsPerOneBanksRow = (numBanks * bankBitWidth) // typeWidthInBit
+    perPhase = max(1, elemsPerOneBanksRow // innerDimLength)
+    maxPhase = min(SIMDWidth // perPhase, innerDimLength // vecSize)
+
+    def transform(row, col):
+        phase = (row // perPhase) % maxPhase
+        colOffSwizzled = ((col // vecSize) ^ phase) * vecSize
+        colOffOrdered = col % vecSize
+        colOff = colOffSwizzled + colOffOrdered
+        return row, colOff
+
+    return T.Layout(shape, transform)
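
To see why this transform avoids bank conflicts, here is a minimal standalone sketch (plain Python, no TVM; the float16 element type and the 64-element inner dimension are illustrative assumptions, and the bank constants mirror the ones above). It reproduces the same arithmetic and checks that, within a phase group, no two rows map a given logical column to the same vector slot:

def swizzle_params(type_bits, inner_dim, vec_size=8,
                   num_banks=32, bank_bits=32, simd_width=16):
    # Same constants and derivation as make_mfma_swizzle_layout above.
    elems_per_bank_row = (num_banks * bank_bits) // type_bits
    per_phase = max(1, elems_per_bank_row // inner_dim)
    max_phase = min(simd_width // per_phase, inner_dim // vec_size)
    return per_phase, max_phase

def swizzle(row, col, per_phase, max_phase, vec_size=8):
    phase = (row // per_phase) % max_phase
    return ((col // vec_size) ^ phase) * vec_size + col % vec_size

per_phase, max_phase = swizzle_params(type_bits=16, inner_dim=64)  # float16 rows
for vec in range(64 // 8):
    # A fixed logical 8-element column vector, read by max_phase consecutive
    # rows, is XOR-scattered onto max_phase distinct physical vector slots.
    slots = {swizzle(row, vec * 8, per_phase, max_phase) // 8
             for row in range(max_phase)}
    assert len(slots) == max_phase  # conflict-free within the row group
print(per_phase, max_phase)  # 1 8 for this configuration
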
bitblas/tl/mfma_macro_generator.py (85 changes: 61 additions & 24 deletions)
@@ -6,6 +6,7 @@
 from tvm import DataType
 from tvm.tir import PrimExpr
 from tvm.runtime import convert
+from typing import Optional
 from .utils import (
     mfma_store_index_map,)

@@ -30,22 +31,29 @@ class MatrixCoreIntrinEmitter(object):
         "e5m2_float8": "e5m2",
     }

+    # k_pack represents the number of elements in a vectorized instruction.
+    # Detailed information can be found in the Triton documentation:
+    # https://github.com/triton-lang/triton/blob/433037206d8870f0b82a3cd669097001084a29ed/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp#L419
+    k_pack = 1
+    # Represents the thread binding in the form of (tx, warp_n, warp_m)
+    is_m_first = False

     def __init__(
         self,
-        a_dtype="float16",
-        b_dtype="float16",
-        accum_dtype="float16",
-        a_transposed=False,
-        b_transposed=False,
-        block_row_warps=2,
-        block_col_warps=2,
-        warp_row_tiles=8,
-        warp_col_tiles=8,
-        chunk=16,
-        reduce_k=1,
-        num_elems_per_byte=1,
+        a_dtype: str = "float16",
+        b_dtype: str = "float16",
+        accum_dtype: str = "float16",
+        a_transposed: bool = False,
+        b_transposed: bool = False,
+        block_row_warps: int = 2,
+        block_col_warps: int = 2,
+        warp_row_tiles: int = 8,
+        warp_col_tiles: int = 8,
+        chunk: int = 16,
+        reduce_k: int = 1,
+        num_elems_per_byte: int = 1,
+        k_pack: Optional[int] = None,
+        is_m_first: Optional[bool] = False,
     ):
         self.a_dtype = a_dtype
         self.b_dtype = b_dtype
@@ -63,6 +71,9 @@ def __init__(
         self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
         self._initialize_mfma_prefix(self.k_dim)
         self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
+        self._initialize_k_pack(k_pack)
+        self._initialize_is_m_first(is_m_first)
+
         self.warp_rows = warp_row_tiles // self.micro_size_x
         self.warp_cols = warp_col_tiles // self.micro_size_y
         self.reduce_k = reduce_k
@@ -113,19 +124,31 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
         self.micro_size_y = n_dim
         self.micro_size_k = k_dim

+    def _initialize_k_pack(self, k_pack: Optional[int] = None):
+        if k_pack is not None:
+            self.k_pack = k_pack
+
+    def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
+        if is_m_first is not None:
+            self.is_m_first = is_m_first
+
     def get_ldmatrix_index_map(self, is_b=False):
         from .mfma_layout import (
             shared_16x4_to_local_64x1_layout_A,
             shared_4x16_to_local_64x1_layout_B,
             shared_16x16_to_local_64x4_layout_A,
             shared_16x16_to_local_64x4_layout_B,
+            shared_16x32_to_local_64x8_layout_A,
+            shared_16x32_to_local_64x8_layout_B,
             thread_id_shared_access_64x1_to_16x4_layout_A,
             thread_id_shared_access_64x1_to_4x16_layout_B,
             thread_id_shared_access_64x4_to_16x16_layout_A,
             thread_id_shared_access_64x4_to_16x16_layout_B,
+            thread_id_shared_access_64x8_to_16x32_layout_A,
+            thread_id_shared_access_64x8_to_16x32_layout_B,
         )

-        k_dim = self.k_dim
+        k_dim = self.k_dim * self.k_pack
         transposed = self.a_transposed if not is_b else self.b_transposed
         if k_dim == 4:
             index_map = shared_16x4_to_local_64x1_layout_A
@@ -140,6 +163,13 @@ def get_ldmatrix_index_map(self, is_b=False):
             if is_b:
                 index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B
                 reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
+        elif k_dim == 32:
+            index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A
+            reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A
+
+            if is_b:
+                index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
+                reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
         else:
             raise ValueError("k_dim must be 4 or 16 currently")
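
As a sanity check on the new 16x32 maps selected here, the following standalone snippet (plain Python; the two functions are copied verbatim from mfma_layout.py above) verifies that the shared-to-local layout and the thread-indexed reverse map are exact inverses, which is what the ldmatrix_a/ldmatrix_b loaders rely on:

def shared_16x32_to_local_64x8_layout_A(i, j):
    thread_id = i + 16 * (j // 8)
    local = (j % 8)
    return thread_id, local

def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
    i = thread_id % 16
    j = (thread_id // 16) * 8 + local_id
    return i, j

# Every (i, j) in the 16x32 tile must round-trip through (thread_id, local).
for i in range(16):
    for j in range(32):
        tid, local = shared_16x32_to_local_64x8_layout_A(i, j)
        assert 0 <= tid < 64 and 0 <= local < 8
        assert thread_id_shared_access_64x8_to_16x32_layout_A(tid, local) == (i, j)
print("16x32 <-> 64x8 A-layout round-trip OK")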

@@ -181,6 +211,7 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
         micro_size_x = self.micro_size_x
         micro_size_k = self.micro_size_k
         local_size_a = self.local_size_a
+        k_pack = self.k_pack
         is_transposed = self.a_transposed

         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
@@ -196,18 +227,20 @@ def _warp_ldmatrix_a(
             tx, _, warp_m = self.extract_thread_binding(thread_bindings)
             if is_transposed:
                 for i in T.serial(warp_rows):
-                    for local_id in T.vectorized(local_size_a):
+                    for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (rk * chunk + ki * micro_size_k,
                                 warp_m * warp_row_tiles + i * micro_size_x)
-                        A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
+                                                                                         r + col]
             else:
                 for i in T.serial(warp_rows):
-                    for local_id in T.vectorized(local_size_a):
+                    for local_id in T.vectorized(k_pack * local_size_a):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (warp_m * warp_row_tiles + i * micro_size_x,
                                 rk * chunk + ki * micro_size_k)
-                        A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col]
+                        A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
+                                                                                         r + col]

         return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)

@@ -218,6 +251,7 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
         micro_size_y = self.micro_size_y
         micro_size_k = self.micro_size_k
         local_size_b = self.local_size_b
+        k_pack = self.k_pack
         is_transposed = self.b_transposed

         _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
@@ -234,22 +268,24 @@ def _warp_ldmatrix_b(

             if is_transposed:
                 for j in T.serial(warp_cols):
-                    for local_id in T.vectorized(local_size_b):
+                    for local_id in T.vectorized(k_pack * local_size_b):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (
                             warp_n * warp_col_tiles + j * micro_size_y,
                             rk * chunk + ki * micro_size_k,
                         )
-                        B_local_buf[j * local_size_b + local_id] = B_shared_buf[l + row, r + col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
+                                                                                         r + col]
             else:
                 for j in T.serial(warp_cols):
-                    for local_id in T.vectorized(local_size_b):
+                    for local_id in T.vectorized(k_pack * local_size_b):
                         row, col = T.meta_var(reverse_index_map(tx, local_id))
                         l, r = (
                             rk * chunk + ki * micro_size_k,
                             warp_n * warp_col_tiles + j * micro_size_y,
                         )
-                        B_local_buf[j * local_size_b + local_id] = B_shared_buf[l + row, r + col]
+                        B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
+                                                                                         r + col]

         return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
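
The net effect of k_pack in both loaders is a wider contiguous per-thread copy. A back-of-the-envelope sketch (assuming the 16x16x16 float16 MFMA shape with a 64-lane wavefront; the LDS instruction names are indicative only, not emitted by this code):

M_DIM, K_DIM, WARP_SIZE = 16, 16, 64
local_size = (M_DIM * K_DIM) // WARP_SIZE       # 4 fp16 values per lane
for k_pack in (1, 2):
    elems = k_pack * local_size                 # the T.vectorized extent above
    bytes_per_lane = elems * 2                  # fp16 is 2 bytes
    print(f"k_pack={k_pack}: {elems} x fp16 = {bytes_per_lane} B per lane"
          f" (roughly a ds_read_b{bytes_per_lane * 8})")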

@@ -259,6 +295,7 @@ def mfma(self, A_local_buf, B_local_buf, C_local_buf):
         local_size_a = self.local_size_a
         local_size_b = self.local_size_b
         local_size_out = self.local_size_out
+        k_pack = self.k_pack
         mfma_suffix = self.mfma_suffix
         a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype
         compute_a_dtype = a_dtype if local_size_a == 1 else f"{a_dtype}x{local_size_a}"
@@ -267,7 +304,7 @@ def mfma(self, A_local_buf, B_local_buf, C_local_buf):

         @T.macro
         def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
-            for i, j in T.grid(warp_rows, warp_cols):
+            for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
                 T.tvm_mfma(
                     mfma_suffix,
                     "row",
@@ -276,9 +313,9 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
                     compute_b_dtype,
                     compute_out_dtype,
                     B_local_buf.data,
-                    (j * local_size_b) // local_size_b,
+                    ((j * k_pack + kp) * local_size_b) // local_size_b,
                     A_local_buf.data,
-                    (i * local_size_a) // local_size_a,
+                    ((i * k_pack + kp) * local_size_a) // local_size_a,
                     C_local_buf.data,
                     (i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
                     dtype=compute_out_dtype,
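
The new register offsets are easiest to read with concrete numbers. A tiny illustrative sketch (plain Python; warp_rows = warp_cols = k_pack = 2 are made-up values) of which local fragments each (kp, i, j) iteration hands to tvm_mfma:

warp_rows = warp_cols = k_pack = 2
for kp in range(k_pack):
    for i in range(warp_rows):
        for j in range(warp_cols):
            # Mirrors ((i * k_pack + kp) * local_size_a) // local_size_a and
            # ((j * k_pack + kp) * local_size_b) // local_size_b above.
            a_frag = i * k_pack + kp
            b_frag = j * k_pack + kp
            print(f"kp={kp} i={i} j={j}: A fragment {a_frag}, B fragment {b_frag}")

Because the k-packed fragments of a tile sit next to each other in the local buffers, the kp loop walks adjacent registers before moving on to the next warp tile, while the accumulator offset depends only on (i, j).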