[TL] Add TL Layout and Macro utils (#174)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix for matrix support

* lint fix

* dispatch tensor core based on shapes

* update install commands

* import scripts

* remove shared mem hack

* revert change for swizzling

* bug fix

* tl examples

* Enhance Swizzle

* lint fix

* test fix

* lint fix

* optimize layout

* update tl utils.

* macro optimization

* test fix
LeiWang1999 authored Sep 4, 2024
1 parent 3aa9439 · commit c15744e
Showing 6 changed files with 344 additions and 11 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
10 changes: 10 additions & 0 deletions bitblas/tl/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .utils import (
get_swizzle_layout, # noqa: F401
mma_store_index_map, # noqa: F401
get_ldmatrix_offset, # noqa: F401
)

from .macro_generator import TensorCorePTXMacroGenerator # noqa: F401
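The new package re-exports the layout helpers and the macro generator from one namespace, so downstream kernels can pull in everything with a single import. A minimal usage sketch (assuming bitblas is installed at this commit):

from bitblas.tl import (
    get_swizzle_layout,          # bank-conflict-free shared-memory layout
    mma_store_index_map,         # register-to-shared index map for MMA results
    get_ldmatrix_offset,         # per-lane shared-memory offset for ldmatrix
    TensorCorePTXMacroGenerator,
)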
211 changes: 211 additions & 0 deletions bitblas/tl/macro_generator.py
@@ -0,0 +1,211 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import tvm.tl.language as T

from tvm import DataType
from tvm.runtime import convert
from .utils import (
mma_store_index_map,
get_ldmatrix_offset,
)

lift = convert


class TensorCorePTXMacroGenerator(object):
"""
To eliminate Python syntax within TIR Macro.
"""

M_DIM = 16
N_DIM = 16
WARP_SIZE = 32
dtype_abbrv = {
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"e4m3_float8": "e4m3",
"e5m2_float8": "e5m2",
}

def __init__(
self,
a_dtype="float16",
b_dtype="float16",
accum_dtype="float16",
a_transposed=False,
b_transposed=False,
block_row_warps=2,
block_col_warps=2,
warp_row_tiles=8,
warp_col_tiles=8,
chunk=16,
threads=128,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
self._initialize_mma_prefix(self.k_dim)
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self._initialize_thread_axis(threads, self.WARP_SIZE, block_row_warps, block_col_warps)

def _initialize_k_dim(self, a_dtype="float16"):
self.k_dim = 256 // DataType(a_dtype).bits

def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
self.local_size_a = (m_dim * k_dim) // warp_size
self.local_size_b = (n_dim * k_dim) // warp_size
self.local_size_out = (m_dim * n_dim) // warp_size

def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]

def _initialize_mma_prefix(self, k_dim=16):
if k_dim == 16:
self.mma_prefix = "m16n8k16"
elif k_dim == 32:
self.mma_prefix = "m16n8k32"
else:
raise ValueError("Unsupported k_dim")

def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
self.micro_size_x = m_dim
self.micro_size_y = n_dim
self.micro_size_k = k_dim

def _initialize_thread_axis(self,
threads=128,
warp_size=32,
block_row_warps=2,
block_col_warps=2):
self.threads = threads
        # tx/ty/tz are derived from thread_bindings inside each macro rather than here.

@staticmethod
@T.macro
def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(inst.warp_rows, inst.warp_cols):
T.ptx_mma(
inst.accum_dtype,
"m16n8k16",
"row",
"col",
inst.a_dtype_abbrv,
inst.b_dtype_abbrv,
inst.accum_dtype_abbrv,
A_local_buf.data,
i * inst.local_size_a,
B_local_buf.data,
j * inst.local_size_b,
C_local_buf.data,
i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out,
T.bool(False),
)

            # second MMA covering the other n8 half of the 16x16 accumulator tile
            T.ptx_mma(
                inst.accum_dtype,
                inst.mma_prefix,
"row",
"col",
inst.a_dtype_abbrv,
inst.b_dtype_abbrv,
inst.accum_dtype_abbrv,
A_local_buf.data,
i * inst.local_size_a,
B_local_buf.data,
j * inst.local_size_b + lift(inst.local_size_b) // 2,
C_local_buf.data,
i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out +
lift(inst.local_size_out) // 2,
T.bool(False),
)

@staticmethod
@T.macro
def LDMATRIX_A(
inst,
A_local_buf,
A_shared_buf,
ki,
thread_bindings,
):
stride = inst.chunk
tx = thread_bindings % inst.WARP_SIZE
        ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
for i in T.serial(inst.warp_rows):
T.ptx_ldmatrix(
"float16",
T.bool(False),
4,
".b16",
A_local_buf.data,
i * inst.local_size_a,
T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x,
ki * inst.micro_size_k,]),
get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, False),
)

@staticmethod
@T.macro
def LDMATRIX_B(
inst,
B_local_buf,
B_shared_buf,
ki,
thread_bindings,
):
stride = inst.chunk
tx = thread_bindings % inst.WARP_SIZE
tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
for j in T.serial(inst.warp_cols):
T.ptx_ldmatrix(
"float16",
T.bool(False), # TODO(lei): should be optimized
4,
".b16",
B_local_buf.data,
j * inst.local_size_b,
T.address_of(B_shared_buf[tz * inst.warp_col_tiles + j * inst.micro_size_y,
ki * inst.micro_size_k,]),
get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, True),
)

    # STS (store to shared)
    # The MMA store is emulated element-by-element instead of using the TVM intrinsic,
    # because the intrinsic assumes threadIdx.x always equals the warp size.
@staticmethod
@T.macro
def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
tx = thread_bindings % inst.WARP_SIZE
ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
for i, j in T.grid(inst.warp_rows, inst.warp_cols):
for local_id in T.serial(inst.local_size_out):
row, col = T.meta_var(mma_store_index_map(tx, local_id))
C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row,
col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) +
j * inst.local_size_out + local_id]
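To see what the constructor derives from its hints, here is a small illustrative sketch (not part of the commit; it assumes a TVM build that ships tvm.tl so the module imports cleanly):

from bitblas.tl.macro_generator import TensorCorePTXMacroGenerator

gen = TensorCorePTXMacroGenerator(
    a_dtype="float16", b_dtype="float16", accum_dtype="float32",
    block_row_warps=2, block_col_warps=2,
    warp_row_tiles=32, warp_col_tiles=32, chunk=32, threads=128)

assert gen.k_dim == 16                        # 256 bits / 16-bit elements
assert gen.mma_prefix == "m16n8k16"           # fp16 MMA shape
assert gen.warp_rows == gen.warp_cols == 2    # 32-wide warp tiles / 16-wide MMA tiles
assert gen.local_size_a == gen.local_size_out == 8  # (16*16) elements / 32 lanes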
119 changes: 119 additions & 0 deletions bitblas/tl/utils.py
@@ -0,0 +1,119 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm import arith
from tvm import DataType
from typing import Union, Literal


def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
ana = arith.Analyzer()
BANK_SIZE_BYTES = 128
if isinstance(dtype, str):
dtype = DataType(dtype)
col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % (
BANK_SIZE_BYTES // dtype.bits)
    # Use transaction bits to support diverse dtypes:
    # for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 16 bits = 512 bits
    # for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
coalescent_bits = dtype.bits * row_size
# permutation on 4 banks, each bank has 32 bits
bank_elems = BANK_SIZE_BYTES // dtype.bits
new_col_idx_outer = None
if coalescent_bits % 1024 == 0:
# Use 8 * 8 permuted layout
        # Every number below corresponds to 8 consecutive fp16 numbers in shared mem, i.e. one read
# Every row below corresponds to 32 banks
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
# 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4
# 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
row_idx_sub = row_idx % bank_elems
new_col_idx_outer = col_idx_outer ^ row_idx_sub
else:
assert coalescent_bits % 512 == 0
# Use 8 * 4 permuted layout
        # Every number below corresponds to 8 consecutive fp16 numbers in shared mem, i.e. one read
# Every row below corresponds to 16 banks
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 3 2 1 0
# 0 1 2 3 ==> 3 2 1 0
        # View with 8 elements per row:
        # 0 1 2 3 0 1 2 3 ==> 0 1 2 3 0 1 2 3
        # 0 1 2 3 0 1 2 3 ==> 1 0 3 2 1 0 3 2
        # 0 1 2 3 0 1 2 3 ==> 2 3 0 1 2 3 0 1
        # 0 1 2 3 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub = row_idx % bank_elems
        # number of elements sharing one 32-bit bank word
interleave_elems = 32 // dtype.bits
new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems)

assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits"
return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)
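The XOR pattern is easy to verify without TVM. For fp16 in the 1024-bit branch, bank_elems = 128 // 16 = 8 and new_col_idx_outer = col_idx_outer ^ (row_idx % 8), which reproduces the 8 x 8 table in the comment above (a plain-Python check, not part of the commit):

for row in range(8):
    print([col ^ row for col in range(8)])
# [0, 1, 2, 3, 4, 5, 6, 7]
# [1, 0, 3, 2, 5, 4, 7, 6]
# ...
# [7, 6, 5, 4, 3, 2, 1, 0]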


def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
row = thread_id % 16
col = 8 * (thread_id // 16) + local_id % 8
return row, col


def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
row = 8 * (thread_id // 16) + (thread_id % 8)
col = 8 * ((thread_id % 16) // 8) + local_id % 8
return row, col


def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
row = thread_id % 16
col = local_id + (thread_id // 16) * 16
return row, col


def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
row = (thread_id // 16) * 8 + (thread_id % 8)
col = local_id + 16 * ((thread_id % 16) // 8)
return row, col


def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
row = 8 * (local_id % 4 // 2) + (thread_id // 4)
col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2)
return row, col
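Each of these maps distributes a warp's 32 lanes x 8 registers over a 16 x 16 tile. A quick plain-Python sanity check (an illustration, not part of the commit) that the load and store maps are bijections onto the tile:

tile = {(r, c) for r in range(16) for c in range(16)}
loads = {ldmatrix_32x8_to_shared_16x16_layout(t, l)
         for t in range(32) for l in range(8)}
stores = {mma_store_32x8_to_shared_16x16_layout(t, l)
          for t in range(32) for l in range(8)}
assert loads == tile and stores == tile  # every element covered exactly once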


def get_ldmatrix_offset(
matrix: Literal["A", "B"],
row_idx,
col_idx,
stride,
dtype: Literal["float16", "int8"] = "float16",
transpose: bool = False,
):
assert matrix in ["A", "B"], "matrix should be either A or B"
transform_func = (
ldmatrix_32x8_to_shared_16x16_layout
if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_b)
transform_func_trans = (
ldmatrix_trans_32x8_to_shared_16x16_layout
if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_a)
if matrix == "A":
assert not transpose, "A matrix should not be transposed"
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx)
return new_row_idx * stride + new_col_idx


def mma_store_index_map(*args, **kwargs):
return mma_store_32x8_to_shared_16x16_layout(*args, **kwargs)
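A worked example of get_ldmatrix_offset (note that row_idx/col_idx are effectively the lane id and register index): for an fp16 A fragment with a shared-memory row stride of 16, lane 5 with register 0 maps to (row 5, col 0), i.e. element offset 80:

offset = get_ldmatrix_offset("A", 5, 0, stride=16, dtype="float16", transpose=False)
assert offset == 5 * 16 + 0 == 80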
10 changes: 3 additions & 7 deletions integration/BitNet/utils_quant.py
@@ -165,7 +165,7 @@ def activation_quant(self, x, num_bits=8):
         Qp = 2**(num_bits - 1) - 1
         s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
         result = (x * s).round().clamp(Qn, Qp)
-        return result.type(torch.int8)
+        return result.type(torch.int8), s

     @torch.compile
     def post_quant_process(self, input, si, sw):
@@ -186,16 +186,14 @@ def native_forward(self, input):
         return out

     def forward_fp32_simulated(self, input):
-        quant_input = self.activation_quant(input, self.input_bits).detach()
+        quant_input, si = self.activation_quant(input, self.input_bits)
         quant_weight = self.weight_quant(self.weight).detach()

         fp32_simulated_input = quant_input.float()
         fp32_simulated_weight = quant_weight.float()
         fp32_simulated_out = nn.functional.linear(fp32_simulated_input, fp32_simulated_weight)

         sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
-        Qp = 2**(self.input_bits - 1) - 1
-        si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
         # dividing by (si * sw) in one step can produce inf (the product may underflow)
         out = fp32_simulated_out / si
         out = out / sw
@@ -206,11 +204,9 @@ def forward_fp32_simulated(self, input):

     def forward(self, input):
         # return self.forward_fp32_simulated(input)
-        quant_input = self.activation_quant(input, self.input_bits).detach()
+        quant_input, si = self.activation_quant(input, self.input_bits)
         fp32_out = self.bitblas_matmul(quant_input, self.qweight)
         sw = self.sw
-        Qp = self.Qp
-        si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
         # dividing by (si * sw) in one step can produce inf (the product may underflow)
         out = self.post_quant_process(fp32_out, si, sw)

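The sequential division kept by this diff is deliberate: in reduced precision, si * sw can underflow to zero, so dividing by the product yields inf even when the true result is finite. A contrived fp16 illustration (not from the commit; the real tensors and dtypes are data-dependent):

import torch

out = torch.tensor([2.0**-12], dtype=torch.float16)
si = torch.tensor([2.0**-13], dtype=torch.float16)
sw = torch.tensor([2.0**-13], dtype=torch.float16)

print(si * sw)          # tensor([0.]) -- 2**-26 underflows fp16 (min subnormal is 2**-24)
print(out / (si * sw))  # tensor([inf]) -- division by the underflowed zero
print(out / si / sw)    # tensor([16384.]) -- two-step division stays finite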
3 changes: 0 additions & 3 deletions testing/python/tilelang/test_tilelang_dequantize_gemm.py
@@ -113,9 +113,6 @@ def run_gemm(

     print(f"output is {out}")

-    with open("debug/kernel.cu", "w") as f:
-        f.write(mod.mod.imported_modules[0].get_source())
-
     def ref_program(A, qB):
         import torch
