-
Notifications
You must be signed in to change notification settings - Fork 39
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TL] Add TL Layout and Macro utils (#174)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability * Refactor import statements for improved readability and maintainability * Refactor import statements for improved readability and maintainability * disable failure email for ci * remove email notifications. * move relax pass from testing to mlc_llm * Refactor scripts with se check_eual_ref_scripts_with_emitter function * Lint Fix * Refactor scripts with se check_eual_ref_scripts_with_emitter function * buf fix for matrix support * lint fix * dispatch tensor core based on shapes * update install commands * import scripts * remove shared mem hack * revert change for swizzling * bug fix * tl examples * Enhance Swizzle * lint fix * test fix * lint fix * optimize layout * update tl utils. * macro optimization * test fix
- Loading branch information
1 parent
3aa9439
commit c15744e
Showing
6 changed files
with
344 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
from .utils import ( | ||
get_swizzle_layout, # noqa: F401 | ||
mma_store_index_map, # noqa: F401 | ||
get_ldmatrix_offset, # noqa: F401 | ||
) | ||
|
||
from .macro_generator import TensorCorePTXMacroGenerator # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
import tvm.tl.language as T | ||
|
||
from tvm import DataType | ||
from tvm.runtime import convert | ||
from .utils import ( | ||
mma_store_index_map, | ||
get_ldmatrix_offset, | ||
) | ||
|
||
lift = convert | ||
|
||
|
||
class TensorCorePTXMacroGenerator(object): | ||
""" | ||
To eliminate Python syntax within TIR Macro. | ||
""" | ||
|
||
M_DIM = 16 | ||
N_DIM = 16 | ||
WARP_SIZE = 32 | ||
dtype_abbrv = { | ||
"float16": "fp16", | ||
"bfloat16": "bf16", | ||
"float32": "fp32", | ||
"int8": "int8", | ||
"int32": "int32", | ||
"e4m3_float8": "e4m3", | ||
"e5m2_float8": "e5m2", | ||
} | ||
|
||
def __init__( | ||
self, | ||
a_dtype="float16", | ||
b_dtype="float16", | ||
accum_dtype="float16", | ||
a_transposed=False, | ||
b_transposed=False, | ||
block_row_warps=2, | ||
block_col_warps=2, | ||
warp_row_tiles=8, | ||
warp_col_tiles=8, | ||
chunk=16, | ||
threads=128, | ||
): | ||
self.a_dtype = a_dtype | ||
self.b_dtype = b_dtype | ||
self.accum_dtype = accum_dtype | ||
self.a_transposed = a_transposed | ||
self.b_transposed = b_transposed | ||
# Hint Information | ||
self.block_row_warps = block_row_warps | ||
self.block_col_warps = block_col_warps | ||
self.warp_row_tiles = warp_row_tiles | ||
self.warp_col_tiles = warp_col_tiles | ||
self.chunk = chunk | ||
self._initialize_k_dim(a_dtype) | ||
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) | ||
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) | ||
self._initialize_mma_prefix(self.k_dim) | ||
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) | ||
self.warp_rows = warp_row_tiles // self.micro_size_x | ||
self.warp_cols = warp_col_tiles // self.micro_size_y | ||
self._initialize_thread_axis(threads, self.WARP_SIZE, block_row_warps, block_col_warps) | ||
|
||
def _initialize_k_dim(self, a_dtype="float16"): | ||
self.k_dim = 256 // DataType(a_dtype).bits | ||
|
||
def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): | ||
self.local_size_a = (m_dim * k_dim) // warp_size | ||
self.local_size_b = (n_dim * k_dim) // warp_size | ||
self.local_size_out = (m_dim * n_dim) // warp_size | ||
|
||
def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): | ||
self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] | ||
self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] | ||
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] | ||
|
||
def _initialize_mma_prefix(self, k_dim=16): | ||
if k_dim == 16: | ||
self.mma_prefix = "m16n8k16" | ||
elif k_dim == 32: | ||
self.mma_prefix = "m16n8k32" | ||
else: | ||
raise ValueError("Unsupported k_dim") | ||
|
||
def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16): | ||
self.micro_size_x = m_dim | ||
self.micro_size_y = n_dim | ||
self.micro_size_k = k_dim | ||
|
||
def _initialize_thread_axis(self, | ||
threads=128, | ||
warp_size=32, | ||
block_row_warps=2, | ||
block_col_warps=2): | ||
self.threads = threads | ||
# thread_bindings = T.env_thread("threadIdx.x") | ||
# self.tx = thread_bindings % warp_size | ||
# self.ty = (thread_bindings // warp_size) % block_row_warps | ||
# self.tz = thread_bindings // (warp_size * block_row_warps) | ||
|
||
@staticmethod | ||
@T.macro | ||
def MMA(inst, A_local_buf, B_local_buf, C_local_buf): | ||
for i, j in T.grid(inst.warp_rows, inst.warp_cols): | ||
T.ptx_mma( | ||
inst.accum_dtype, | ||
"m16n8k16", | ||
"row", | ||
"col", | ||
inst.a_dtype_abbrv, | ||
inst.b_dtype_abbrv, | ||
inst.accum_dtype_abbrv, | ||
A_local_buf.data, | ||
i * inst.local_size_a, | ||
B_local_buf.data, | ||
j * inst.local_size_b, | ||
C_local_buf.data, | ||
i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out, | ||
T.bool(False), | ||
) | ||
|
||
T.ptx_mma( | ||
inst.accum_dtype, | ||
"m16n8k16", | ||
"row", | ||
"col", | ||
inst.a_dtype_abbrv, | ||
inst.b_dtype_abbrv, | ||
inst.accum_dtype_abbrv, | ||
A_local_buf.data, | ||
i * inst.local_size_a, | ||
B_local_buf.data, | ||
j * inst.local_size_b + lift(inst.local_size_b) // 2, | ||
C_local_buf.data, | ||
i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out + | ||
lift(inst.local_size_out) // 2, | ||
T.bool(False), | ||
) | ||
|
||
@staticmethod | ||
@T.macro | ||
def LDMATRIX_A( | ||
inst, | ||
A_local_buf, | ||
A_shared_buf, | ||
ki, | ||
thread_bindings, | ||
): | ||
stride = inst.chunk | ||
tx = thread_bindings % inst.WARP_SIZE | ||
ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps | ||
# self.ty = (thread_bindings // warp_size) % block_row_warps | ||
# self.tz = thread_bindings // (warp_size * block_row_warps) | ||
for i in T.serial(inst.warp_rows): | ||
T.ptx_ldmatrix( | ||
"float16", | ||
T.bool(False), | ||
4, | ||
".b16", | ||
A_local_buf.data, | ||
i * inst.local_size_a, | ||
T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x, | ||
ki * inst.micro_size_k,]), | ||
get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, False), | ||
) | ||
|
||
@staticmethod | ||
@T.macro | ||
def LDMATRIX_B( | ||
inst, | ||
B_local_buf, | ||
B_shared_buf, | ||
ki, | ||
thread_bindings, | ||
): | ||
stride = inst.chunk | ||
tx = thread_bindings % inst.WARP_SIZE | ||
tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps) | ||
for j in T.serial(inst.warp_cols): | ||
T.ptx_ldmatrix( | ||
"float16", | ||
T.bool(False), # TODO(lei): should be optimized | ||
4, | ||
".b16", | ||
B_local_buf.data, | ||
j * inst.local_size_b, | ||
T.address_of(B_shared_buf[tz * inst.warp_col_tiles + j * inst.micro_size_y, | ||
ki * inst.micro_size_k,]), | ||
get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, True), | ||
) | ||
|
||
# STS | ||
# MMA Store must be in simulated instead of TVM Intrins | ||
# As TVM Intrins is like a hack that the threadIdx.x should be always | ||
# equal to the warp_size | ||
@staticmethod | ||
@T.macro | ||
def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings): | ||
tx = thread_bindings % inst.WARP_SIZE | ||
ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps | ||
tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps) | ||
for i, j in T.grid(inst.warp_rows, inst.warp_cols): | ||
for local_id in T.serial(inst.local_size_out): | ||
row, col = T.meta_var(mma_store_index_map(tx, local_id)) | ||
C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row, | ||
col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) + | ||
j * inst.local_size_out + local_id] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from tvm import arith | ||
from tvm import DataType | ||
from typing import Union, Literal | ||
|
||
|
||
def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]): | ||
ana = arith.Analyzer() | ||
BANK_SIZE_BYTES = 128 | ||
if isinstance(dtype, str): | ||
dtype = DataType(dtype) | ||
col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % ( | ||
BANK_SIZE_BYTES // dtype.bits) | ||
# use transaction bits to support diverse dtype. | ||
# for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits | ||
# for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits | ||
coalescent_bits = dtype.bits * row_size | ||
# permutation on 4 banks, each bank has 32 bits | ||
bank_elems = BANK_SIZE_BYTES // dtype.bits | ||
new_col_idx_outer = None | ||
print(f"coalescent_bits: {coalescent_bits}") | ||
if coalescent_bits % 1024 == 0: | ||
# Use 8 * 8 permuted layout | ||
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read | ||
# Every row below corresponds to 32 banks | ||
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7 | ||
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6 | ||
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5 | ||
# 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4 | ||
# 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3 | ||
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2 | ||
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1 | ||
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0 | ||
row_idx_sub = row_idx % bank_elems | ||
new_col_idx_outer = col_idx_outer ^ row_idx_sub | ||
else: | ||
assert coalescent_bits % 512 == 0 | ||
# Use 8 * 4 permuted layout | ||
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read | ||
# Every row below corresponds to 16 banks | ||
# 0 1 2 3 ==> 0 1 2 3 | ||
# 0 1 2 3 ==> 0 1 2 3 | ||
# 0 1 2 3 ==> 1 0 3 2 | ||
# 0 1 2 3 ==> 1 0 3 2 | ||
# 0 1 2 3 ==> 2 3 0 1 | ||
# 0 1 2 3 ==> 2 3 0 1 | ||
# 0 1 2 3 ==> 3 2 1 0 | ||
# 0 1 2 3 ==> 3 2 1 0 | ||
# View with 8 elements per row: | ||
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3 | ||
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2 | ||
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1 | ||
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0 | ||
row_idx_sub = row_idx % bank_elems | ||
# Interleave elems per byte | ||
interleave_elems = 32 // dtype.bits | ||
new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems) | ||
|
||
assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits" | ||
return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner) | ||
|
||
|
||
def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): | ||
row = thread_id % 16 | ||
col = 8 * (thread_id // 16) + local_id % 8 | ||
return row, col | ||
|
||
|
||
def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id): | ||
row = 8 * (thread_id // 16) + (thread_id % 8) | ||
col = 8 * ((thread_id % 16) // 8) + local_id % 8 | ||
return row, col | ||
|
||
|
||
def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id): | ||
row = thread_id % 16 | ||
col = local_id + (thread_id // 16) * 16 | ||
return row, col | ||
|
||
|
||
def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id): | ||
row = (thread_id // 16) * 8 + (thread_id % 8) | ||
col = local_id + 16 * ((thread_id % 16) // 8) | ||
return row, col | ||
|
||
|
||
def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id): | ||
row = 8 * (local_id % 4 // 2) + (thread_id // 4) | ||
col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2) | ||
return row, col | ||
|
||
|
||
def get_ldmatrix_offset( | ||
matrix: Literal["A", "B"], | ||
row_idx, | ||
col_idx, | ||
stride, | ||
dtype: Literal["float16", "int8"] = "float16", | ||
transpose: bool = False, | ||
): | ||
assert matrix in ["A", "B"], "matrix should be either A or B" | ||
transform_func = ( | ||
ldmatrix_32x8_to_shared_16x16_layout | ||
if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_b) | ||
transform_func_trans = ( | ||
ldmatrix_trans_32x8_to_shared_16x16_layout | ||
if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_a) | ||
if matrix == "A": | ||
assert not transpose, "A matrix should not be transposed" | ||
new_row_idx, new_col_idx = transform_func(row_idx, col_idx) | ||
return new_row_idx * stride + new_col_idx | ||
else: | ||
new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx) | ||
return new_row_idx * stride + new_col_idx | ||
|
||
|
||
def mma_store_index_map(*args, **kwargs): | ||
return mma_store_32x8_to_shared_16x16_layout(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters