Skip to content

Commit

Permalink
[AMD][TL] Introduce K Pack and a Conflict Free swizzling into Matrix …
Browse files Browse the repository at this point in the history
…Core (#248)

* Implement MFMA Make Swizzle Layout

* Implement Test

* format code

* test fix

* submodule update

* implement block level test

* lint fix
  • Loading branch information
LeiWang1999 authored Nov 27, 2024
1 parent b481405 commit 6f9c6ed
Show file tree
Hide file tree
Showing 20 changed files with 526 additions and 110 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from a12155 to e52254
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional, List
from bitblas.tl.utils import (
get_mma_micro_size,
make_swizzle_layout,
make_mma_swizzle_layout as make_swizzle_layout,
)

from bitblas.tl.mma_macro_generator import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Optional, List
from bitblas.tl.utils import (
get_mma_micro_size,
make_swizzle_layout,
make_mma_swizzle_layout as make_swizzle_layout,
)
from bitblas.ops.general_matmul.tilelang.dense.matmul_tensorcore import (
MatmulFineGrainScheduler,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional, List, Literal
from bitblas.tl.utils import (
get_mma_micro_size, # noqa: F401
make_swizzle_layout, # noqa: F401
make_mma_swizzle_layout as make_swizzle_layout, # noqa: F401
)

from bitblas.tl.mma_macro_generator import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional, List
from bitblas.tl.utils import (
get_mma_micro_size, # noqa: F401
make_swizzle_layout, # noqa: F401
make_mma_swizzle_layout as make_swizzle_layout, # noqa: F401
index_to_coordinates, # noqa: F401
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional
from bitblas.tl.utils import (
get_mma_micro_size, # noqa: F401
make_swizzle_layout, # noqa: F401
make_mma_swizzle_layout as make_swizzle_layout, # noqa: F401
)
from .finegrained_primitive_tensorcore import MatmulDequantizeFineGrainedScheduler
from bitblas.tl.mma_macro_generator import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional, List
from bitblas.tl.utils import (
get_mma_micro_size, # noqa: F401
make_swizzle_layout, # noqa: F401
make_mma_swizzle_layout as make_swizzle_layout, # noqa: F401
index_to_coordinates, # noqa: F401
)
from bitblas.base.arch import TileDevice
Expand Down
1 change: 0 additions & 1 deletion bitblas/tl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Licensed under the MIT License.

from .utils import (
get_swizzle_layout, # noqa: F401
mma_store_index_map, # noqa: F401
get_ldmatrix_offset, # noqa: F401
)
Expand Down
52 changes: 52 additions & 0 deletions bitblas/tl/mfma_layout.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from tvm import DataType
import tvm.tl.language as T
from tvm.runtime import convert


Expand Down Expand Up @@ -71,3 +74,52 @@ def thread_id_shared_access_64x4_to_16x16_layout_C_n_m(thread_id, local_id):
i = thread_id % 16
j = local_id + (thread_id // 16) * 4
return i, j


def thread_id_shared_access_64x8_to_16x32_layout_A(thread_id, local_id):
    """Map a (thread_id, local_id) pair of a 64x8 register layout to the
    (row, col) coordinates of the 16x32 shared-memory tile it reads (A operand).

    Rows cycle through thread_id % 16; each group of 16 threads covers a
    contiguous 8-element column segment, offset by local_id within it.
    """
    group, row = divmod(thread_id, 16)
    col = group * 8 + local_id
    return row, col


def shared_16x32_to_local_64x8_layout_A(i, j):
    """Inverse of the A-operand map: convert a (row, col) coordinate of the
    16x32 shared tile into the owning (thread_id, local_id) of the 64x8
    register layout.

    Each 8-wide column segment belongs to a distinct group of 16 threads.
    """
    segment, lane = divmod(j, 8)
    return i + 16 * segment, lane


def thread_id_shared_access_64x8_to_16x32_layout_B(thread_id, local_id):
    """Map a (thread_id, local_id) pair of a 64x8 register layout to the
    (row, col) coordinates of the 16x32 shared-memory tile it reads (B operand).

    Transposed counterpart of the A map: columns cycle through
    thread_id % 16, rows come from the 8-element segment per thread group.
    """
    group, col = divmod(thread_id, 16)
    row = group * 8 + local_id
    return row, col


def shared_16x32_to_local_64x8_layout_B(i, j):
    """Inverse of the B-operand map: convert a (row, col) coordinate of the
    16x32 shared tile into the owning (thread_id, local_id) of the 64x8
    register layout.
    """
    segment, lane = divmod(i, 8)
    return j + 16 * segment, lane


def make_mfma_swizzle_layout(shared_buf, vecSize=8):
    """Build a swizzled layout for *shared_buf* intended to avoid shared-memory
    bank conflicts for Matrix Core (MFMA) access patterns.

    The column index is XOR-swizzled in units of ``vecSize`` elements, with a
    phase derived from the row, so consecutive rows hit different banks.
    The row index is left untouched.

    Parameters
    ----------
    shared_buf : buffer with ``.shape`` and ``.dtype`` (tvm buffer)
    vecSize : int
        Width, in elements, of one vectorized access (swizzle granularity).
    """
    shape = shared_buf.shape
    elem_bits = DataType(shared_buf.dtype).bits

    # Bank geometry constants; presumably AMD LDS (32 banks x 32-bit entries)
    # with a 16-lane SIMD — TODO confirm against target hardware docs.
    num_banks = 32
    bank_bit_width = 32
    simd_width = 16

    inner_dim = shape[-1]
    elems_per_bank_row = (num_banks * bank_bit_width) // elem_bits
    # At least one row per phase; more when a row spans less than a bank row.
    per_phase = max(1, elems_per_bank_row // inner_dim)
    max_phase = min(simd_width // per_phase, inner_dim // vecSize)

    def swizzle(row, col):
        phase = (row // per_phase) % max_phase
        swizzled_base = ((col // vecSize) ^ phase) * vecSize
        within_vec = col % vecSize
        return row, swizzled_base + within_vec

    return T.Layout(shape, swizzle)
85 changes: 61 additions & 24 deletions bitblas/tl/mfma_macro_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from tvm import DataType
from tvm.tir import PrimExpr
from tvm.runtime import convert
from typing import Optional
from .utils import (
mfma_store_index_map,)

Expand All @@ -30,22 +31,29 @@ class MatrixCoreIntrinEmitter(object):
"e5m2_float8": "e5m2",
}

# k_pack represents the number of elements in a vectorized instruction
# Detail information can be found in the triton documentation
# https://github.com/triton-lang/triton/blob/433037206d8870f0b82a3cd669097001084a29ed/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp#L419
k_pack = 1
# Represent the thread binding in the form of (tx, warp_n, warp_m)
is_m_first = False

def __init__(
self,
a_dtype="float16",
b_dtype="float16",
accum_dtype="float16",
a_transposed=False,
b_transposed=False,
block_row_warps=2,
block_col_warps=2,
warp_row_tiles=8,
warp_col_tiles=8,
chunk=16,
reduce_k=1,
num_elems_per_byte=1,
a_dtype: str = "float16",
b_dtype: str = "float16",
accum_dtype: str = "float16",
a_transposed: bool = False,
b_transposed: bool = False,
block_row_warps: int = 2,
block_col_warps: int = 2,
warp_row_tiles: int = 8,
warp_col_tiles: int = 8,
chunk: int = 16,
reduce_k: int = 1,
num_elems_per_byte: int = 1,
k_pack: Optional[int] = None,
is_m_first: Optional[bool] = False,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
Expand All @@ -63,6 +71,9 @@ def __init__(
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
self._initialize_mfma_prefix(self.k_dim)
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self._initialize_k_pack(k_pack)
self._initialize_is_m_first(is_m_first)

self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self.reduce_k = reduce_k
Expand Down Expand Up @@ -113,19 +124,31 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
self.micro_size_y = n_dim
self.micro_size_k = k_dim

def _initialize_k_pack(self, k_pack: Optional[int] = None):
if k_pack is not None:
self.k_pack = k_pack

def _initialize_is_m_first(self, is_m_first: Optional[bool] = False):
if is_m_first is not None:
self.is_m_first = is_m_first

def get_ldmatrix_index_map(self, is_b=False):
from .mfma_layout import (
shared_16x4_to_local_64x1_layout_A,
shared_4x16_to_local_64x1_layout_B,
shared_16x16_to_local_64x4_layout_A,
shared_16x16_to_local_64x4_layout_B,
shared_16x32_to_local_64x8_layout_A,
shared_16x32_to_local_64x8_layout_B,
thread_id_shared_access_64x1_to_16x4_layout_A,
thread_id_shared_access_64x1_to_4x16_layout_B,
thread_id_shared_access_64x4_to_16x16_layout_A,
thread_id_shared_access_64x4_to_16x16_layout_B,
thread_id_shared_access_64x8_to_16x32_layout_A,
thread_id_shared_access_64x8_to_16x32_layout_B,
)

k_dim = self.k_dim
k_dim = self.k_dim * self.k_pack
transposed = self.a_transposed if not is_b else self.b_transposed
if k_dim == 4:
index_map = shared_16x4_to_local_64x1_layout_A
Expand All @@ -140,6 +163,13 @@ def get_ldmatrix_index_map(self, is_b=False):
if is_b:
index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B
reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
elif k_dim == 32:
index_map = shared_16x32_to_local_64x8_layout_B if transposed else shared_16x32_to_local_64x8_layout_A
reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_B if transposed else thread_id_shared_access_64x8_to_16x32_layout_A

if is_b:
index_map = shared_16x32_to_local_64x8_layout_A if transposed else shared_16x32_to_local_64x8_layout_B
reverse_index_map = thread_id_shared_access_64x8_to_16x32_layout_A if transposed else thread_id_shared_access_64x8_to_16x32_layout_B
else:
raise ValueError("k_dim must be 4 or 16 currently")

Expand Down Expand Up @@ -181,6 +211,7 @@ def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
micro_size_x = self.micro_size_x
micro_size_k = self.micro_size_k
local_size_a = self.local_size_a
k_pack = self.k_pack
is_transposed = self.a_transposed

_, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
Expand All @@ -196,18 +227,20 @@ def _warp_ldmatrix_a(
tx, _, warp_m = self.extract_thread_binding(thread_bindings)
if is_transposed:
for i in T.serial(warp_rows):
for local_id in T.vectorized(local_size_a):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (rk * chunk + ki * micro_size_k,
warp_m * warp_row_tiles + i * micro_size_x)
A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col]
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
r + col]
else:
for i in T.serial(warp_rows):
for local_id in T.vectorized(local_size_a):
for local_id in T.vectorized(k_pack * local_size_a):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (warp_m * warp_row_tiles + i * micro_size_x,
rk * chunk + ki * micro_size_k)
A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col]
A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l + row,
r + col]

return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)

Expand All @@ -218,6 +251,7 @@ def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
micro_size_y = self.micro_size_y
micro_size_k = self.micro_size_k
local_size_b = self.local_size_b
k_pack = self.k_pack
is_transposed = self.b_transposed

_, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
Expand All @@ -234,22 +268,24 @@ def _warp_ldmatrix_b(

if is_transposed:
for j in T.serial(warp_cols):
for local_id in T.vectorized(local_size_b):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
warp_n * warp_col_tiles + j * micro_size_y,
rk * chunk + ki * micro_size_k,
)
B_local_buf[j * local_size_b + local_id] = B_shared_buf[l + row, r + col]
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
r + col]
else:
for j in T.serial(warp_cols):
for local_id in T.vectorized(local_size_b):
for local_id in T.vectorized(k_pack * local_size_b):
row, col = T.meta_var(reverse_index_map(tx, local_id))
l, r = (
rk * chunk + ki * micro_size_k,
warp_n * warp_col_tiles + j * micro_size_y,
)
B_local_buf[j * local_size_b + local_id] = B_shared_buf[l + row, r + col]
B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row,
r + col]

return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)

Expand All @@ -259,6 +295,7 @@ def mfma(self, A_local_buf, B_local_buf, C_local_buf):
local_size_a = self.local_size_a
local_size_b = self.local_size_b
local_size_out = self.local_size_out
k_pack = self.k_pack
mfma_suffix = self.mfma_suffix
a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype
compute_a_dtype = a_dtype if local_size_a == 1 else f"{a_dtype}x{local_size_a}"
Expand All @@ -267,7 +304,7 @@ def mfma(self, A_local_buf, B_local_buf, C_local_buf):

@T.macro
def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(warp_rows, warp_cols):
for kp, i, j in T.grid(k_pack, warp_rows, warp_cols):
T.tvm_mfma(
mfma_suffix,
"row",
Expand All @@ -276,9 +313,9 @@ def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
compute_b_dtype,
compute_out_dtype,
B_local_buf.data,
(j * local_size_b) // local_size_b,
((j * k_pack + kp) * local_size_b) // local_size_b,
A_local_buf.data,
(i * local_size_a) // local_size_a,
((i * k_pack + kp) * local_size_a) // local_size_a,
C_local_buf.data,
(i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
dtype=compute_out_dtype,
Expand Down
Loading

0 comments on commit 6f9c6ed

Please sign in to comment.