[DEV][TL] Support AMD Matrix Code Implementation #237

Merged Nov 7, 2024 (23 commits)

Commits
c4853ec
Refactor Simplify function to handle multiple functions in IRModule
LeiWang1999 Oct 16, 2024
9a21acf
Update submodule commit reference
LeiWang1999 Oct 17, 2024
f8d046b
Add CUDA_DEVICE_ORDER environment variable to bashrc
LeiWang1999 Oct 17, 2024
c1371dd
test fix
LeiWang1999 Oct 17, 2024
416cad2
lint fix
LeiWang1999 Oct 17, 2024
9209d1e
Refactor test_general_matmul_bf16.py to use bitblas.testing.main()
LeiWang1999 Oct 17, 2024
1cf7570
Update submodule commit reference
LeiWang1999 Oct 17, 2024
5fec040
Update Ubuntu version in install scripts based on LLVM version
LeiWang1999 Oct 18, 2024
4e1a0d2
Update Ubuntu version in install scripts based on LLVM version
LeiWang1999 Oct 18, 2024
fa85f8c
Update submodule commit reference
LeiWang1999 Oct 19, 2024
429d5b5
Update submodule commit reference
LeiWang1999 Oct 19, 2024
4003509
Update submodule commit reference
LeiWang1999 Oct 20, 2024
1d86582
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 20, 2024
df3af0d
Update submodule commit reference
LeiWang1999 Oct 28, 2024
1f1e027
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 28, 2024
732dda6
Update submodule commit reference
LeiWang1999 Oct 29, 2024
ebffbfa
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Oct 29, 2024
ff227fa
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 4, 2024
ac62936
[Dev] Update subproject commit for TVM
LeiWang1999 Nov 7, 2024
a7a239c
ignore profiler directories.
LeiWang1999 Nov 7, 2024
dcedbde
MFMA Support
LeiWang1999 Nov 7, 2024
e0b36f5
lint fix
LeiWang1999 Nov 7, 2024
fe668f9
Merge branch 'main' of https://github.com/microsoft/BitBLAS into amd_hip
LeiWang1999 Nov 7, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -76,3 +76,6 @@ models/frozenmodels/

# .bitblas_database
.bitblas_database

# rocprof workloads
workloads
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from be013f to c6be66
139 changes: 89 additions & 50 deletions bitblas/__init__.py
@@ -3,62 +3,21 @@
import sys
import os

# installing tvm
install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
install_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl")
sys.path.insert(0, install_tvm_path + "/python")

develop_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
develop_cutlass_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
os.environ["PYTHONPATH"] = develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", "")
os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
os.environ["TL_TEMPLATE_PATH"] = os.path.join(install_tvm_path, "src/tl")
sys.path.insert(0, develop_tvm_path + "/python")

import tvm as tvm # noqa: E402
from . import gpu # noqa: F401
from .base import (
TileDevice, # noqa: F401
fast_tune, # noqa: F401
ApplyDefaultSchedule, # noqa: F401
ApplyFastTuning, # noqa: F401
BlockInfo, # noqa: F401
IterInfo, # noqa: F401
ScheduleRule, # noqa: F401
normalize_prim_func, # noqa: F401
try_inline, # noqa: F401
try_inline_contiguous_spatial, # noqa: F401
)

from . import testing # noqa: F401
from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401
from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401
from .module import Linear # noqa: F401

import warnings
import functools
import logging
from tqdm import tqdm


class TqdmLoggingHandler(logging.Handler):
""" Custom logging handler that directs log output to tqdm progress bar to avoid interference. """
"""Custom logging handler that directs log output to tqdm progress bar to avoid interference."""

def __init__(self, level=logging.NOTSET):
""" Initialize the handler with an optional log level. """
"""Initialize the handler with an optional log level."""
super().__init__(level)

def emit(self, record):
""" Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted. """
"""Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted."""
try:
msg = self.format(record)
tqdm.write(msg)
@@ -67,8 +26,8 @@ def emit(self, record):


def set_log_level(level):
""" Set the logging level for the module's logger.
"""Set the logging level for the module's logger.

Args:
level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO).
OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'
@@ -80,15 +39,17 @@ def set_log_level(level):


def _init_logger():
""" Initialize the logger specific for this module with custom settings and a Tqdm-based handler. """
"""Initialize the logger specific for this module with custom settings and a Tqdm-based handler."""
logger = logging.getLogger(__name__)
handler = TqdmLoggingHandler()
formatter = logging.Formatter(
fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False
set_log_level('WARNING')
set_log_level("WARNING")


_init_logger()
@@ -107,12 +68,90 @@ def new_func(*args, **kwargs):
warnings.warn(
f"Call to deprecated function {func.__name__} ({reason}).",
category=DeprecationWarning,
stacklevel=2)
stacklevel=2,
)
return func(*args, **kwargs)

return new_func

return decorator


logger = logging.getLogger(__name__)

# SETUP ENVIRONMENT VARIABLES
CUTLASS_NOT_FOUND_MESSAGE = ("CUTLASS is not installed or found in the expected path"
                             ", which may lead to compilation bugs when utilizing the tilelang backend.")
TL_TEMPLATE_NOT_FOUND_MESSAGE = ("TileLang is not installed or found in the expected path"
                                 ", which may lead to compilation bugs when utilizing the tilelang backend.")

# Handle TVM_IMPORT_PYTHON_PATH to import tvm from the specified path
TVM_IMPORT_PYTHON_PATH = os.environ.get("TVM_IMPORT_PYTHON_PATH", None)

if TVM_IMPORT_PYTHON_PATH is not None:
    os.environ["PYTHONPATH"] = (TVM_IMPORT_PYTHON_PATH + ":" + os.environ.get("PYTHONPATH", ""))
    sys.path.insert(0, TVM_IMPORT_PYTHON_PATH + "/python")
else:
    # remove the existing tvm path in PYTHONPATH
    def remove_tvm_path(path):
        return "tvm" in path

    # installed 3rdparty tvm
    install_tvm_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm")
    if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path:
        os.environ["PYTHONPATH"] = ":".join(
            filter(remove_tvm_path,
                   os.environ.get("PYTHONPATH", "").split(":")))
        sys.path = [path for path in sys.path if not remove_tvm_path(path)]

        os.environ["PYTHONPATH"] = (
            install_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
        sys.path.insert(0, install_tvm_path + "/python")

    # developed 3rdparty tvm
    develop_tvm_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm")
    if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path:
        os.environ["PYTHONPATH"] = ":".join(
            filter(remove_tvm_path,
                   os.environ.get("PYTHONPATH", "").split(":")))
        sys.path = [path for path in sys.path if not remove_tvm_path(path)]
        os.environ["PYTHONPATH"] = (
            develop_tvm_path + "/python:" + os.environ.get("PYTHONPATH", ""))
        sys.path.insert(0, develop_tvm_path + "/python")

if os.environ.get("TL_CUTLASS_PATH", None) is None:
    install_cutlass_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "3rdparty", "cutlass")
    develop_cutlass_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "cutlass")
    if os.path.exists(install_cutlass_path):
        os.environ["TL_CUTLASS_PATH"] = install_cutlass_path + "/include"
    elif (os.path.exists(develop_cutlass_path) and develop_cutlass_path not in sys.path):
        os.environ["TL_CUTLASS_PATH"] = develop_cutlass_path + "/include"
    else:
        logger.warning(CUTLASS_NOT_FOUND_MESSAGE)

import tvm as tvm # noqa: E402
from . import gpu # noqa: F401
from .base import (
TileDevice, # noqa: F401
fast_tune, # noqa: F401
ApplyDefaultSchedule, # noqa: F401
ApplyFastTuning, # noqa: F401
BlockInfo, # noqa: F401
IterInfo, # noqa: F401
ScheduleRule, # noqa: F401
normalize_prim_func, # noqa: F401
try_inline, # noqa: F401
try_inline_contiguous_spatial, # noqa: F401
)

from . import testing # noqa: F401
from .utils import auto_detect_nvidia_target, apply_transform_on_input # noqa: F401
from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401
from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401
from .ops.general_flashatten import FlashAttenConfig, FlashAtten # noqa: F401
from .module import Linear # noqa: F401

__version__ = "0.0.1.dev15"
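The rewritten initialization above makes the TVM and CUTLASS locations overridable through environment variables instead of always bolting the bundled 3rdparty checkouts onto `sys.path`. A minimal usage sketch (the `/opt/...` paths are hypothetical, not values from the PR):

```python
import os

# Both variables are read at import time by the logic in bitblas/__init__.py above.
os.environ["TVM_IMPORT_PYTHON_PATH"] = "/opt/tvm"        # use an out-of-tree TVM; "/opt/tvm/python" is added to sys.path
os.environ["TL_CUTLASS_PATH"] = "/opt/cutlass/include"   # skip the bundled-CUTLASS lookup and its warning

import bitblas  # noqa: E402  (must come after the environment variables are set)

bitblas.set_log_level("DEBUG")  # the tqdm-backed logger configured above defaults to WARNING
```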
@@ -9,7 +9,7 @@
make_swizzle_layout,
)

from bitblas.tl.macro_generator import (
from bitblas.tl.mma_macro_generator import (
TensorCoreIntrinEmitter,
TensorCoreIntrinEmitterWithLadderTransform,
)
@@ -13,7 +13,7 @@
MatmulFineGrainScheduler,
MatmulWeightPropagationScheduler,
)
from bitblas.tl.macro_generator import (
from bitblas.tl.mma_macro_generator import (
INT4TensorCoreIntrinEmitter,
INT4TensorCoreIntrinEmitterWithLadderTransform,
)
@@ -9,7 +9,7 @@
make_swizzle_layout, # noqa: F401
)

from bitblas.tl.macro_generator import (
from bitblas.tl.mma_macro_generator import (
TensorCoreIntrinEmitter, # noqa: F401
)
from bitblas.ops.common import TransformKind # noqa: F401
@@ -10,7 +10,7 @@
index_to_coordinates, # noqa: F401
)

from bitblas.tl.macro_generator import (
from bitblas.tl.mma_macro_generator import (
INT4TensorCoreIntrinEmitter, # noqa: F401
)
from bitblas.base.arch import TileDevice
@@ -9,7 +9,7 @@
make_swizzle_layout, # noqa: F401
)
from .finegrained_primitive_tensorcore import MatmulDequantizeFineGrainedScheduler
from bitblas.tl.macro_generator import (
from bitblas.tl.mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401
)
from bitblas.ops.common import TransformKind # noqa: F401
@@ -11,7 +11,7 @@
)
from bitblas.base.arch import TileDevice
from bitblas.base.roller.hint import Hint
from bitblas.tl.macro_generator import (
from bitblas.tl.mma_macro_generator import (
INT4TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401
)
from bitblas.ops.common import TransformKind # noqa: F401
2 changes: 1 addition & 1 deletion bitblas/tl/__init__.py
@@ -7,7 +7,7 @@
get_ldmatrix_offset, # noqa: F401
)

from .macro_generator import (
from .mma_macro_generator import (
TensorCoreIntrinEmitter, # noqa: F401
TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401
)
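The remaining one-line changes above only touch the import path: the NVIDIA MMA emitters now live in `bitblas.tl.mma_macro_generator`, presumably to keep them apart from the AMD MFMA helpers added below. A sketch of the new spelling for downstream code:

```python
# Before this PR:
#   from bitblas.tl.macro_generator import TensorCoreIntrinEmitter
# After this PR:
from bitblas.tl.mma_macro_generator import (
    TensorCoreIntrinEmitter,
    TensorCoreIntrinEmitterWithLadderTransform,
)
```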
12 changes: 12 additions & 0 deletions bitblas/tl/base_layout.py
@@ -0,0 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


def make_shared_to_local_linear_layout_2d(i, j, stride=16, local_size=4):

    def shared_to_local_linear_layout_2d(i, j):
        thread_id = j + (i // local_size) * stride
        local = (i % local_size)
        return thread_id, local

    return shared_to_local_linear_layout_2d
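`make_shared_to_local_linear_layout_2d` returns a closure that maps a 2-D shared-memory coordinate to a `(thread_id, local_id)` pair: rows are grouped in chunks of `local_size`, and each chunk spans `stride` threads. Note that the factory's own `i` and `j` parameters are shadowed by the inner function and never used. A few worked examples with the defaults (`stride=16`, `local_size=4`); the values below are illustrative, not part of the PR:

```python
from bitblas.tl.base_layout import make_shared_to_local_linear_layout_2d

layout = make_shared_to_local_linear_layout_2d(0, 0)  # the i, j arguments are effectively ignored

print(layout(0, 0))  # (0, 0)  -> thread 0, local slot 0
print(layout(1, 0))  # (0, 1)  -> same thread, next local slot (rows 0-3 share a thread row)
print(layout(5, 7))  # (23, 1) -> thread 23 = 7 + (5 // 4) * 16, slot 1 = 5 % 4
```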
80 changes: 80 additions & 0 deletions bitblas/tl/mfma_layout.py
@@ -0,0 +1,80 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from tvm.runtime import convert


def shared_16x4_to_local_64x1_layout_A(i, j):
    thread_id = (j * 16 + i)
    return thread_id, convert(0)


def thread_id_shared_access_64x1_to_16x4_layout_A(thread_id, local_id):
    i = thread_id % 16
    j = thread_id // 16
    return i, j


def shared_4x16_to_local_64x1_layout_B(i, j):
    thread_id = (i * 16 + j)
    return thread_id, convert(0)


def thread_id_shared_access_64x1_to_4x16_layout_B(thread_id, local_id):
    i = thread_id // 16
    j = thread_id % 16
    return i, j


def shared_16x16_to_local_64x4_layout_C(i, j):
    thread_id = j + (i // 4) * 16
    local = (i % 4)
    return thread_id, local


def shared_16x16_to_ldmatrix_64x4_layout(ind):
    i, j = ind[0], ind[1]
    thread_id, local_id = shared_16x16_to_local_64x4_layout_C(i, j)
    return convert([thread_id, local_id])


def thread_id_shared_access_64x4_to_16x16_layout_A(thread_id, local_id):
    i = thread_id % 16
    j = (thread_id // 16) * 4 + local_id
    return i, j


def shared_16x16_to_local_64x4_layout_A(i, j):
    thread_id = i + 16 * (j // 4)
    local = (j % 4)
    return thread_id, local


def thread_id_shared_access_64x4_to_16x16_layout_B(thread_id, local_id):
    i = local_id + (thread_id // 16) * 4
    j = thread_id % 16
    return i, j


def shared_16x16_to_local_64x4_layout_B(i, j):
    thread_id = j + (i // 4) * 16
    local = (i % 4)
    return thread_id, local


def thread_id_shared_access_64x4_to_16x16_layout_C_smooth(thread_id, local_id):
    i = thread_id % 16
    j = local_id + (thread_id // 16) * 4
    return j, i


def thread_id_shared_access_64x4_to_16x16_layout_C(thread_id, local_id):
    # This is a hacky switch to simulate the performance of the smooth layout.
    is_smooth = os.environ.get("TILE_LANG_SMOOTH_LAYOUT") == "1"
    if is_smooth:
        return thread_id_shared_access_64x4_to_16x16_layout_C_smooth(thread_id, local_id)

    i = local_id + (thread_id // 16) * 4
    j = thread_id % 16
    return i, j
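The `mfma_layout.py` helpers describe how 16x16 MFMA fragments are distributed over a 64-lane wavefront with four values per lane (and the 16x4/4x16 variants with one value per lane). As a quick sanity check — a sketch, not part of the PR, and assuming `TILE_LANG_SMOOTH_LAYOUT` is unset — each shared-to-local mapping should round-trip through its thread-access counterpart:

```python
from bitblas.tl.mfma_layout import (
    shared_16x16_to_local_64x4_layout_A,
    shared_16x16_to_local_64x4_layout_B,
    shared_16x16_to_local_64x4_layout_C,
    thread_id_shared_access_64x4_to_16x16_layout_A,
    thread_id_shared_access_64x4_to_16x16_layout_B,
    thread_id_shared_access_64x4_to_16x16_layout_C,
)


def roundtrips(forward, backward):
    """Check that backward(forward(i, j)) recovers (i, j) for every element of a 16x16 tile."""
    for i in range(16):
        for j in range(16):
            thread_id, local_id = forward(i, j)
            assert 0 <= thread_id < 64 and 0 <= local_id < 4
            assert backward(thread_id, local_id) == (i, j)


roundtrips(shared_16x16_to_local_64x4_layout_A, thread_id_shared_access_64x4_to_16x16_layout_A)
roundtrips(shared_16x16_to_local_64x4_layout_B, thread_id_shared_access_64x4_to_16x16_layout_B)
# The C check assumes TILE_LANG_SMOOTH_LAYOUT is not set to "1".
roundtrips(shared_16x16_to_local_64x4_layout_C, thread_id_shared_access_64x4_to_16x16_layout_C)
```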