diff --git a/.gitignore b/.gitignore
index ca788f982..937edbbae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -76,3 +76,6 @@ models/frozenmodels/
 
 # .bitblas_database
 .bitblas_database
+
+# rocprof workloads
+workloads
diff --git a/3rdparty/tvm b/3rdparty/tvm
index be013f6d5..c6be66d56 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit be013f6d5e623e1787351aac897e270970e33ada
+Subproject commit c6be66d563695bfbaf4f3d46d312e82b6ad9be1d
diff --git a/bitblas/__init__.py b/bitblas/__init__.py
index 3074e3fcb..661556c56 100644
--- a/bitblas/__init__.py
+++ b/bitblas/__init__.py
@@ -3,47 +3,6 @@
 import sys
 import os
 
-# installing tvm
-install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
-install_cutlass_path = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
-if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
-    os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
-    os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
-    os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl")
-    sys.path.insert(0, install_tvm_path + "/python")
-
-develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
-develop_cutlass_path = os.path.join(
-    os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
-if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
-    os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
-    os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
-    os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl")
-    sys.path.insert(0, develop_tvm_path + "/python")
-
-import tvm as tvm  # noqa: E402
-from . import gpu  # noqa: F401
-from .base import (
-    TileDevice,  # noqa: F401
-    fast_tune,  # noqa: F401
-    ApplyDefaultSchedule,  # noqa: F401
-    ApplyFastTuning,  # noqa: F401
-    BlockInfo,  # noqa: F401
-    IterInfo,  # noqa: F401
-    ScheduleRule,  # noqa: F401
-    normalize_prim_func,  # noqa: F401
-    try_inline,  # noqa: F401
-    try_inline_contiguous_spatial,  # noqa: F401
-)
-
-from . import testing  # noqa: F401
-from .utils import auto_detect_nvidia_target, apply_transform_on_input  # noqa: F401
-from .ops.general_matmul import MatmulConfig, Matmul  # noqa: F401
-from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK  # noqa: F401
-from .ops.general_flashatten import FlashAttenConfig, FlashAtten  # noqa: F401
-from .module import Linear  # noqa: F401
-
 import warnings
 import functools
 import logging
@@ -51,14 +10,14 @@
 
 class TqdmLoggingHandler(logging.Handler):
-    """ Custom logging handler that directs log output to tqdm progress bar to avoid interference. """
+    """Custom logging handler that directs log output to tqdm progress bar to avoid interference."""
 
     def __init__(self, level=logging.NOTSET):
-        """ Initialize the handler with an optional log level. """
+        """Initialize the handler with an optional log level."""
         super().__init__(level)
 
     def emit(self, record):
-        """ Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted. """
+        """Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted."""
         try:
             msg = self.format(record)
             tqdm.write(msg)
@@ -67,8 +26,8 @@ def emit(self, record):
 
 
 def set_log_level(level):
-    """ Set the logging level for the module's logger.
- + """Set the logging level for the module's logger. + Args: level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO). OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' @@ -80,15 +39,17 @@ def set_log_level(level): def _init_logger(): - """ Initialize the logger specific for this module with custom settings and a Tqdm-based handler. """ + """Initialize the logger specific for this module with custom settings and a Tqdm-based handler.""" logger = logging.getLogger(__name__) handler = TqdmLoggingHandler() formatter = logging.Formatter( - fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", datefmt="%Y-%m-%d %H:%M:%S") + fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) handler.setFormatter(formatter) logger.addHandler(handler) logger.propagate = False - set_log_level('WARNING') + set_log_level("WARNING") _init_logger() @@ -107,7 +68,8 @@ def new_func(*args, **kwargs): warnings.warn( f"Call to deprecated function {func.__name__} ({reason}).", category=DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) return func(*args, **kwargs) return new_func @@ -115,4 +77,81 @@ def new_func(*args, **kwargs): return decorator +logger = logging.getLogger(__name__) + +# SETUP ENVIRONMENT VARIABLES +CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path") +", which may lead to compilation bugs when utilize tilelang backend." +TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path") +", which may lead to compilation bugs when utilize tilelang backend." + +# Handle TVM_IMPORT_PYTHON_PATH to import tvm from the specified path +TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None) + +if TVM_IMPORT_PYTHON_PATH is not None: + os.environ["PYTHONPATH"] = (TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", "")) + sys.path.insert(0, TVM_IMPORT_PYTHON_PATH + "/python") +else: + # remove the existing tvm path in PYTHONPATH + def remove_tvm_path(path): + return "tvm" in path + + # installed 3rdparty tvm + install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm") + if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: + os.environ["PYTHONPATH"] = ":".join( + filter(remove_tvm_path, + os.environ.get("PYTHONPATH", "").split(":"))) + sys.path = [path for path in sys.path if not remove_tvm_path(path)] + + os.environ["PYTHONPATH"] = ( + install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) + sys.path.insert(0, install_tvm_path + "/python") + + # developed 3rdparty tvm + develop_tvm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm") + if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: + os.environ["PYTHONPATH"] = ":".join( + filter(remove_tvm_path, + os.environ.get("PYTHONPATH", "").split(":"))) + sys.path = [path for path in sys.path if not remove_tvm_path(path)] + os.environ["PYTHONPATH"] = ( + develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")) + sys.path.insert(0, develop_tvm_path + "/python") + +if os.environ.get("TL_CUTLASS_PATH", None) is None: + install_cutlass_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass") + develop_cutlass_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass") + if os.path.exists(install_cutlass_path): + os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include" + elif 
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
index e0d40e6bd..ac3f0435d 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -9,7 +9,7 @@
     make_swizzle_layout,
 )
 
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     TensorCoreIntrinEmitter,
     TensorCoreIntrinEmitterWithLadderTransform,
 )
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py
index 8e55b0231..65d164c20 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py
@@ -13,7 +13,7 @@
     MatmulFineGrainScheduler,
     MatmulWeightPropagationScheduler,
 )
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitter,
     INT4TensorCoreIntrinEmitterWithLadderTransform,
 )
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py
index 060d52b1e..9127c7ae4 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore.py
@@ -9,7 +9,7 @@
     make_swizzle_layout,  # noqa: F401
 )
 
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     TensorCoreIntrinEmitter,  # noqa: F401
 )
 from bitblas.ops.common import TransformKind  # noqa: F401
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py
index 100ab0a31..54303167e 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/finegrained_primitive_tensorcore_s4.py
@@ -10,7 +10,7 @@
     index_to_coordinates,  # noqa: F401
 )
 
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitter,  # noqa: F401
 )
 from bitblas.base.arch import TileDevice
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py
index 4652566c6..d51766cec 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore.py
@@ -9,7 +9,7 @@
     make_swizzle_layout,  # noqa: F401
 )
 from .finegrained_primitive_tensorcore import MatmulDequantizeFineGrainedScheduler
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     TensorCoreIntrinEmitterWithLadderTransform,  # noqa: F401
 )
 from bitblas.ops.common import TransformKind  # noqa: F401
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py
index fef550bd7..ace3052da 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/ladder_weight_transform_tensorcore_s4.py
@@ -11,7 +11,7 @@
 )
 from bitblas.base.arch import TileDevice
 from bitblas.base.roller.hint import Hint
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitterWithLadderTransform,  # noqa: F401
 )
 from bitblas.ops.common import TransformKind  # noqa: F401
diff --git a/bitblas/tl/__init__.py b/bitblas/tl/__init__.py
index 919b70662..3103fbf89 100644
--- a/bitblas/tl/__init__.py
+++ b/bitblas/tl/__init__.py
@@ -7,7 +7,7 @@
     get_ldmatrix_offset,  # noqa: F401
 )
 
-from .macro_generator import (
+from .mma_macro_generator import (
     TensorCoreIntrinEmitter,  # noqa: F401
     TensorCoreIntrinEmitterWithLadderTransform,  # noqa: F401
 )
diff --git a/bitblas/tl/base_layout.py b/bitblas/tl/base_layout.py
new file mode 100644
index 000000000..b60768c8e
--- /dev/null
+++ b/bitblas/tl/base_layout.py
@@ -0,0 +1,12 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+def make_shared_to_local_linear_layout_2d(i, j, stride=16, local_size=4):
+
+    def shared_to_local_linear_layout_2d(i, j):
+        thread_id = j + (i // local_size) * stride
+        local = (i % local_size)
+        return thread_id, local
+
+    return shared_to_local_linear_layout_2d
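The new base_layout helper returns a closure mapping a 2-D shared-memory coordinate to a (thread, register-slot) pair: groups of local_size consecutive rows share a lane slot, and lanes advance across columns with the given stride. A quick sanity check under the defaults (note the outer i, j arguments are placeholders that the inner function shadows):

    from bitblas.tl.base_layout import make_shared_to_local_linear_layout_2d

    layout = make_shared_to_local_linear_layout_2d(0, 0, stride=16, local_size=4)
    # element (i=5, j=3): thread_id = 3 + (5 // 4) * 16 = 19, local = 5 % 4 = 1
    assert layout(5, 3) == (19, 1)
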
diff --git a/bitblas/tl/mfma_layout.py b/bitblas/tl/mfma_layout.py
new file mode 100644
index 000000000..a7302e897
--- /dev/null
+++ b/bitblas/tl/mfma_layout.py
@@ -0,0 +1,80 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import os
+from tvm.runtime import convert
+
+
+def shared_16x4_to_local_64x1_layout_A(i, j):
+    thread_id = (j * 16 + i)
+    return thread_id, convert(0)
+
+
+def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id):
+    i = thread_id % 16
+    j = thread_id // 16
+    return i, j
+
+
+def shared_4x16_to_local_64x1_layout_B(i, j):
+    thread_id = (i * 16 + j)
+    return thread_id, convert(0)
+
+
+def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id):
+    i = thread_id // 16
+    j = thread_id % 16
+    return i, j
+
+
+def shared_16x16_to_local_64x4_layout_C(i, j):
+    thread_id = j + (i // 4) * 16
+    local = (i % 4)
+    return thread_id, local
+
+
+def shared_16x16_to_ldmatrix_64x4_layout(ind):
+    i, j = ind[0], ind[1]
+    thread_id, local_id = shared_16x16_to_local_64x4_layout_C(i, j)
+    return convert([thread_id, local_id])
+
+
+def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
+    i = thread_id % 16
+    j = (thread_id // 16) * 4 + local_id
+    return i, j
+
+
+def shared_16x16_to_local_64x4_layout_A(i, j):
+    thread_id = i + 16 * (j // 4)
+    local = (j % 4)
+    return thread_id, local
+
+
+def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
+    i = local_id + (thread_id // 16) * 4
+    j = thread_id % 16
+    return i, j
+
+
+def shared_16x16_to_local_64x4_layout_B(i, j):
+    thread_id = j + (i // 4) * 16
+    local = (i % 4)
+    return thread_id, local
+
+
+def thread_id_shared_access_64x4_to_16x16_layout_C_smooth(thread_id, local_id):
+    i = thread_id % 16
+    j = local_id + (thread_id // 16) * 4
+    return j, i
+
+
+def thread_id_shared_access_64x4_to_16x16_layout_C(thread_id, local_id):
+    # Hack: select the smooth C layout via an environment flag so the two
+    # variants can be compared for performance.
+    is_smooth = os.environ.get("TILE_LANG_SMOOTH_LAYOUT") == "1"
+    if is_smooth:
+        return thread_id_shared_access_64x4_to_16x16_layout_C_smooth(thread_id, local_id)
+
+    i = local_id + (thread_id // 16) * 4
+    j = thread_id % 16
+    return i, j
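The layout functions above come in forward/reverse pairs: the forward map sends a fragment coordinate to a (lane, register) pair on a 64-lane wavefront, and the matching thread_id_shared_access_* function inverts it. A small self-check for the 16x16 A-fragment pair, pure integer arithmetic so it runs on any host:

    from bitblas.tl.mfma_layout import (
        shared_16x16_to_local_64x4_layout_A,
        thread_id_shared_access_64x4_to_16x16_layout_A,
    )

    for i in range(16):
        for j in range(16):
            tid, local = shared_16x16_to_local_64x4_layout_A(i, j)
            assert 0 <= tid < 64 and 0 <= local < 4
            assert thread_id_shared_access_64x4_to_16x16_layout_A(tid, local) == (i, j)
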
+ """ + + M_DIM = 16 + N_DIM = 16 + WARP_SIZE = 64 + dtype_abbrv = { + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "int8": "int8", + "int32": "int32", + "e4m3_float8": "e4m3", + "e5m2_float8": "e5m2", + } + + def __init__( + self, + a_dtype="float16", + b_dtype="float16", + accum_dtype="float16", + a_transposed=False, + b_transposed=False, + block_row_warps=2, + block_col_warps=2, + warp_row_tiles=8, + warp_col_tiles=8, + chunk=16, + reduce_k=1, + num_elems_per_byte=1, + ): + self.a_dtype = a_dtype + self.b_dtype = b_dtype + self.accum_dtype = accum_dtype + self.a_transposed = a_transposed + self.b_transposed = b_transposed + # Hint Information + self.block_row_warps = block_row_warps + self.block_col_warps = block_col_warps + self.warp_row_tiles = warp_row_tiles + self.warp_col_tiles = warp_col_tiles + self.chunk = chunk + self._initialize_k_dim(a_dtype) + self._initialize_abbrev(a_dtype, b_dtype, accum_dtype) + self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE) + self._initialize_mfma_prefix(self.k_dim) + self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim) + self.warp_rows = warp_row_tiles // self.micro_size_x + self.warp_cols = warp_col_tiles // self.micro_size_y + self.reduce_k = reduce_k + self.threads = (self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k) + self.num_elems_per_byte = num_elems_per_byte + + def _initialize_k_dim(self, a_dtype="float16"): + if isinstance(a_dtype, str): + a_dtype = DataType(a_dtype) + if a_dtype.bits == 32: + self.k_dim = 4 + elif a_dtype.bits in [16, 8]: + self.k_dim = 16 + else: + raise ValueError(f"Unsupported a_dtype = {a_dtype}") + + def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32): + self.local_size_a = (m_dim * k_dim) // warp_size + self.local_size_b = (n_dim * k_dim) // warp_size + self.local_size_out = (m_dim * n_dim) // warp_size + + def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype): + self.a_dtype_abbrv = self.dtype_abbrv[a_dtype] + self.b_dtype_abbrv = self.dtype_abbrv[b_dtype] + self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype] + + def _initialize_mfma_prefix(self, k_dim=16): + in_dtype, out_dtype = self.a_dtype, self.accum_dtype + M_DIM, N_DIM = self.M_DIM, self.N_DIM + out_dtype_abbrv = { + "float16": "f16", + "float32": "f32", + "int8": "i8", + "int32": "i32" + }[out_dtype] + + in_dtype_abbrv = { + "float16": "f16", + "float32": "f32", + "int8": "i8", + "int32": "i32" + }[in_dtype] + + self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}" + + def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16): + self.micro_size_x = m_dim + self.micro_size_y = n_dim + self.micro_size_k = k_dim + + def get_ldmatrix_index_map(self, is_b=False): + from .mfma_layout import ( + shared_16x4_to_local_64x1_layout_A, + shared_4x16_to_local_64x1_layout_B, + shared_16x16_to_local_64x4_layout_A, + shared_16x16_to_local_64x4_layout_B, + thread_id_shared_access_64x1_to_16x4_layout_A, + thread_id_shared_access_64x1_to_4x16_layout_B, + thread_id_shared_access_64x4_to_16x16_layout_A, + thread_id_shared_access_64x4_to_16x16_layout_B, + ) + + k_dim = self.k_dim + transposed = self.a_transposed if not is_b else self.b_transposed + if k_dim == 4: + index_map = shared_16x4_to_local_64x1_layout_A + reverse_index_map = thread_id_shared_access_64x1_to_16x4_layout_A + if is_b: + index_map = shared_16x4_to_local_64x1_layout_A if transposed else shared_4x16_to_local_64x1_layout_B + reverse_index_map = 
+        elif k_dim == 16:
+            index_map = shared_16x16_to_local_64x4_layout_B if transposed else shared_16x16_to_local_64x4_layout_A
+            reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_B if transposed else thread_id_shared_access_64x4_to_16x16_layout_A
+
+            if is_b:
+                index_map = shared_16x16_to_local_64x4_layout_A if transposed else shared_16x16_to_local_64x4_layout_B
+                reverse_index_map = thread_id_shared_access_64x4_to_16x16_layout_A if transposed else thread_id_shared_access_64x4_to_16x16_layout_B
+        else:
+            raise ValueError("k_dim must be 4 or 16 currently")
+
+        return index_map, reverse_index_map
+
+    def ldmatrix_a(self, A_local_buf, A_shared_buf, ki, thread_bindings, rk=0):
+        WARP_SIZE = self.WARP_SIZE
+        block_row_warps = self.block_row_warps
+        block_col_warps = self.block_col_warps
+        warp_row_tiles = self.warp_row_tiles
+        warp_rows = self.warp_rows
+        chunk = self.chunk
+        micro_size_x = self.micro_size_x
+        micro_size_k = self.micro_size_k
+        local_size_a = self.local_size_a
+        is_transposed = self.a_transposed
+
+        _, reverse_index_map = self.get_ldmatrix_index_map(is_b=False)
+
+        @T.macro
+        def _warp_ldmatrix_a(
+            A_local_buf,
+            A_shared_buf,
+            ki,
+            thread_bindings,
+            rk=0,
+        ):
+            tx = thread_bindings % WARP_SIZE
+            tz = (thread_bindings // (WARP_SIZE * block_col_warps)) % block_row_warps
+            if is_transposed:
+                for i in T.serial(warp_rows):
+                    for local_id in T.vectorized(local_size_a):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (rk * chunk + ki * micro_size_k,
+                                tz * warp_row_tiles + i * micro_size_x)
+                        A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col]
+            else:
+                for i in T.serial(warp_rows):
+                    for local_id in T.vectorized(local_size_a):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (tz * warp_row_tiles + i * micro_size_x,
+                                rk * chunk + ki * micro_size_k)
+                        A_local_buf[i * local_size_a + local_id] = A_shared_buf[l + row, r + col]
+
+        return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_bindings, rk)
+
+    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
+        WARP_SIZE = self.WARP_SIZE
+        block_col_warps = self.block_col_warps
+        warp_col_tiles = self.warp_col_tiles
+        warp_cols = self.warp_cols
+        chunk = self.chunk
+        micro_size_y = self.micro_size_y
+        micro_size_k = self.micro_size_k
+        local_size_b = self.local_size_b
+        is_transposed = self.b_transposed
+
+        _, reverse_index_map = self.get_ldmatrix_index_map(is_b=True)
+
+        @T.macro
+        def _warp_ldmatrix_b(
+            B_local_buf,
+            B_shared_buf,
+            ki,
+            thread_bindings,
+            rk=0,
+        ):
+            tx = thread_bindings % WARP_SIZE
+            ty = (thread_bindings // WARP_SIZE) % block_col_warps
+
+            if is_transposed:
+                for j in T.serial(warp_cols):
+                    for local_id in T.vectorized(local_size_b):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            ty * warp_col_tiles + j * micro_size_y,
+                            rk * chunk + ki * micro_size_k,
+                        )
+                        B_local_buf[j * local_size_b + local_id] = B_shared_buf[l + row, r + col]
+            else:
+                for j in T.serial(warp_cols):
+                    for local_id in T.vectorized(local_size_b):
+                        row, col = T.meta_var(reverse_index_map(tx, local_id))
+                        l, r = (
+                            rk * chunk + ki * micro_size_k,
+                            ty * warp_col_tiles + j * micro_size_y,
+                        )
+                        B_local_buf[j * local_size_b + local_id] = B_shared_buf[l + row, r + col]
+
+        return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_bindings, rk)
+
+    def mfma(self, A_local_buf, B_local_buf, C_local_buf):
+        warp_rows = self.warp_rows
+        warp_cols = self.warp_cols
+        local_size_a = self.local_size_a
+        local_size_b = self.local_size_b
+        local_size_out = self.local_size_out
+        mfma_suffix = self.mfma_suffix
+        a_dtype, b_dtype, out_dtype = self.a_dtype, self.b_dtype, self.accum_dtype
+        compute_a_dtype = a_dtype if local_size_a == 1 else f"{a_dtype}x{local_size_a}"
+        compute_b_dtype = b_dtype if local_size_b == 1 else f"{b_dtype}x{local_size_b}"
+        compute_out_dtype = out_dtype if local_size_out == 1 else f"{out_dtype}x{local_size_out}"
+
+        @T.macro
+        def _warp_mma(A_local_buf, B_local_buf, C_local_buf):
+            for i, j in T.grid(warp_rows, warp_cols):
+                T.tvm_mfma(
+                    mfma_suffix,
+                    "row",
+                    "row",
+                    compute_a_dtype,
+                    compute_b_dtype,
+                    compute_out_dtype,
+                    A_local_buf.data,
+                    (i * local_size_a) // local_size_a,
+                    B_local_buf.data,
+                    (j * local_size_b) // local_size_b,
+                    C_local_buf.data,
+                    (i * warp_cols * local_size_out + j * local_size_out) // local_size_out,
+                    dtype=compute_out_dtype,
+                )
+
+        return _warp_mma(A_local_buf, B_local_buf, C_local_buf)
+
+    def stmatrix(self, C_local_buf, C_buf, thread_bindings, pid_m=None, pid_n=None):
+        WARP_SIZE = self.WARP_SIZE
+        block_row_warps = self.block_row_warps
+        block_col_warps = self.block_col_warps
+        warp_rows = self.warp_rows
+        warp_cols = self.warp_cols
+        local_size_out = self.local_size_out
+
+        is_global = pid_m is not None and pid_n is not None
+        BLOCK_M = block_row_warps * warp_rows
+        BLOCK_N = block_col_warps * warp_cols
+        M_DIM, N_DIM = self.M_DIM, self.N_DIM
+
+        # STS
+        # The MMA store must be simulated rather than emitted via the TVM
+        # intrinsic, since the intrinsic assumes threadIdx.x always equals
+        # the warp size.
+        @T.macro
+        def _warp_stmatrix_shared(C_local_buf, C_buf, thread_bindings):
+            tx = thread_bindings % WARP_SIZE
+            ty = (thread_bindings // WARP_SIZE) % block_row_warps
+            tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps
+
+            for i, j in T.grid(warp_rows, warp_cols):
+                for local_id in T.serial(local_size_out):
+                    row, col = T.meta_var(mfma_store_index_map(tx, local_id))
+                    C_buf[ty * warp_rows + i, tz * warp_cols + j, row,
+                          col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
+                                             local_id]
+
+        @T.macro
+        def _warp_stmatrix_global(C_local_buf, C_buf, thread_bindings):
+            tx = thread_bindings % WARP_SIZE
+            ty = (thread_bindings // WARP_SIZE) % block_row_warps
+            tz = (thread_bindings // (WARP_SIZE * block_row_warps)) % block_col_warps
+
+            for i, j in T.grid(warp_rows, warp_cols):
+                for local_id in T.serial(local_size_out):
+                    row, col = T.meta_var(mfma_store_index_map(tx, local_id))
+                    C_buf[(pid_m * BLOCK_M + ty * warp_rows + i) * M_DIM + row,
+                          (pid_n * BLOCK_N + tz * warp_cols + j) * N_DIM +
+                          col] = C_local_buf[i * warp_cols * local_size_out + j * local_size_out +
+                                             local_id]
+
+        return _warp_stmatrix_global(C_local_buf, C_buf,
+                                     thread_bindings) if is_global else _warp_stmatrix_shared(
+                                         C_local_buf, C_buf, thread_bindings)
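MatrixCoreIntrinEmitter mirrors the existing CUDA TensorCoreIntrinEmitter but targets AMD's 64-lane wavefronts and the mfma intrinsics. A sketch of how it is parameterized; the values are illustrative and follow the constructor defaults and derivations above:

    from bitblas.tl.mfma_macro_generator import MatrixCoreIntrinEmitter

    emitter = MatrixCoreIntrinEmitter(
        a_dtype="float16",
        b_dtype="float16",
        accum_dtype="float32",
        block_row_warps=2,
        block_col_warps=2,
        warp_row_tiles=32,
        warp_col_tiles=32,
        chunk=32,
    )
    # 16-bit inputs select k_dim = 16, i.e. one mfma_f32_16x16x16f16 per
    # micro tile; each of the 64 lanes holds 4 accumulators (16*16/64).
    assert emitter.mfma_suffix == "f32_16x16x16f16"
    assert emitter.local_size_out == 4
    assert (emitter.warp_rows, emitter.warp_cols) == (2, 2)
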
diff --git a/bitblas/tl/macro_generator.py b/bitblas/tl/mma_macro_generator.py
similarity index 100%
rename from bitblas/tl/macro_generator.py
rename to bitblas/tl/mma_macro_generator.py
diff --git a/bitblas/tl/utils.py b/bitblas/tl/utils.py
index 18f0d3274..9f354f8ce 100644
--- a/bitblas/tl/utils.py
+++ b/bitblas/tl/utils.py
@@ -12,6 +12,8 @@
     ldmatrix_16x32_to_shared_16x32_layout_b,
     mma_store_32x8_to_shared_16x16_layout,
 )
+from .mfma_layout import (
+    thread_id_shared_access_64x4_to_16x16_layout_C,)
 
 
 def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
@@ -110,6 +112,10 @@ def mma_store_index_map(*args, **kwargs):
     return mma_store_32x8_to_shared_16x16_layout(*args, **kwargs)
 
 
+def mfma_store_index_map(*args, **kwargs):
+    return thread_id_shared_access_64x4_to_16x16_layout_C(*args, **kwargs)
+
+
 def get_mma_micro_size(dtype: Literal["float16", "int8"]):
     # TODO(lei): FP8 related precision support.
     # Basic Tensor Core Matrix Multiply operation Unit
diff --git a/integration/BitNet/int4_kernel/tl_int4xint2.py b/integration/BitNet/int4_kernel/tl_int4xint2.py
index 670f72b07..7fc2fc7a9 100644
--- a/integration/BitNet/int4_kernel/tl_int4xint2.py
+++ b/integration/BitNet/int4_kernel/tl_int4xint2.py
@@ -11,7 +11,7 @@
 from bitblas.tl.utils import (make_swizzle_layout, index_to_coordinates)
 from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitter,)
 from bitblas.ops.base_scheduler import simplify_prim_func
diff --git a/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py
index e879f1524..c7c80a3f1 100644
--- a/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py
+++ b/integration/BitNet/int4_kernel/tl_int4xint2_ladder_weight_only.py
@@ -9,7 +9,7 @@
 from tvm import tl as TL
 import tvm.tl.language as T
 from bitblas.tl.utils import make_swizzle_layout, index_to_coordinates
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitterWithLadderTransform,)
 from bitblas.gpu.intrin.lop3 import decode_i2s_to_i4s
 from bitblas.ops.base_scheduler import simplify_prim_func
diff --git a/integration/BitNet/int4_kernel/tl_int4xint4.py b/integration/BitNet/int4_kernel/tl_int4xint4.py
index 5b040db89..e3bc20649 100644
--- a/integration/BitNet/int4_kernel/tl_int4xint4.py
+++ b/integration/BitNet/int4_kernel/tl_int4xint4.py
@@ -9,7 +9,7 @@
 from bitblas.tl.utils import (
     make_swizzle_layout,)
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitter,)
 from bitblas.ops.base_scheduler import simplify_prim_func
diff --git a/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py
index b0e0c4d5d..6f8a8dcce 100644
--- a/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py
+++ b/integration/BitNet/int4_kernel/tl_int4xint4_ladder_weight_only.py
@@ -9,7 +9,7 @@
 from tvm import tl as TL
 import tvm.tl.language as T
 from bitblas.tl.utils import make_swizzle_layout
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitterWithLadderTransform,)
 from bitblas.ops.base_scheduler import simplify_prim_func
diff --git a/integration/BitNet/int4_kernel/tl_int8xint8.py b/integration/BitNet/int4_kernel/tl_int8xint8.py
index e809c673e..3a5583094 100644
--- a/integration/BitNet/int4_kernel/tl_int8xint8.py
+++ b/integration/BitNet/int4_kernel/tl_int8xint8.py
@@ -8,7 +8,7 @@
 from tvm import tl as TL
 import tvm.tl.language as T
 from bitblas.tl.utils import get_swizzle_layout
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     TensorCoreIntrinEmitter,)
 from bitblas.ops.base_scheduler import simplify_prim_func
diff --git a/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py b/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py
index 733441f2f..be1f7ea56 100644
--- a/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py
+++ b/integration/BitNet/int4_kernel/tl_int8xint8_ladder_weight_only.py
@@ -9,7 +9,7 @@
 from tvm import tl as TL
 import tvm.tl.language as T
 from bitblas.tl.utils import get_swizzle_layout
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     TensorCoreIntrinEmitterWithLadderTransform,)
 from bitblas.ops.base_scheduler import simplify_prim_func
diff --git a/testing/python/tilelang/test_tilelang_dequantize_gemm.py b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
index 620ef5be7..dd63274e2 100644
--- a/testing/python/tilelang/test_tilelang_dequantize_gemm.py
+++ b/testing/python/tilelang/test_tilelang_dequantize_gemm.py
@@ -9,7 +9,7 @@
 import tvm.tl.language as T
 from bitblas.quantization import _tir_packed_to_unsigned_convert
 from bitblas.tl.utils import (make_swizzle_layout)
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     TensorCoreIntrinEmitterWithLadderTransform,)
 from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
diff --git a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py
index 0dfe07633..f02fcfbe1 100644
--- a/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py
+++ b/testing/python/tilelang/test_tilelang_dyanmic_symbolic.py
@@ -9,7 +9,7 @@
 from tvm import tl as TL
 import tvm.tl.language as T
 from bitblas.tl.utils import get_swizzle_layout
-from bitblas.tl.macro_generator import (TensorCoreIntrinEmitter)
+from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter)
 
 torch.manual_seed(0)
diff --git a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py
index 37c210b91..ee93d33b0 100644
--- a/testing/python/tilelang/test_tilelang_gemm_s4_mma.py
+++ b/testing/python/tilelang/test_tilelang_gemm_s4_mma.py
@@ -11,7 +11,7 @@
 from bitblas.tl.utils import (
     make_swizzle_layout,)
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     INT4TensorCoreIntrinEmitter,
     INT4TensorCoreIntrinEmitterWithLadderTransform,
 )
diff --git a/testing/python/tilelang/test_tilelang_macro_gemm.py b/testing/python/tilelang/test_tilelang_macro_gemm.py
index 4c4cf8f59..c3fcce6a1 100644
--- a/testing/python/tilelang/test_tilelang_macro_gemm.py
+++ b/testing/python/tilelang/test_tilelang_macro_gemm.py
@@ -9,7 +9,7 @@
 from tvm import tl as TL
 import tvm.tl.language as T
 from bitblas.tl.utils import get_swizzle_layout
-from bitblas.tl.macro_generator import (
+from bitblas.tl.mma_macro_generator import (
     TensorCoreIntrinEmitter,
     TensorCoreIntrinEmitterWithLadderTransform,
 )
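The remaining hunks are mechanical: bitblas.tl.macro_generator was renamed to bitblas.tl.mma_macro_generator to distinguish the NVIDIA MMA emitters from the new AMD MFMA emitter, and every in-tree import was updated accordingly. Out-of-tree callers would need the same one-line change:

    # before
    # from bitblas.tl.macro_generator import TensorCoreIntrinEmitter
    # after
    from bitblas.tl.mma_macro_generator import TensorCoreIntrinEmitter

Imports that go through the package itself, e.g. from bitblas.tl, keep working, since bitblas/tl/__init__.py re-exports the emitters from the renamed module.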