[Dev] Enhance TileLang Backend and fix a bug for INT4xINT2 (#236)
* Enhance installation script with detailed logging and error handling

* Implement INT4-related scheduler

* Support INT4 as an operator

* Update INT4 matrix multiplication

* Fine-tune

* Lint fix

* Update README

* Refactor function retrieval from IRModule and update related usages

* Enhance

* Lint fix

---------

Co-authored-by: LeiWang1999 <leiwaang1999@Outlook.com>
LeiWang1999 and LeiWang1999 authored Nov 5, 2024
1 parent 0fa8e37 commit 04ffcc3
Showing 12 changed files with 177 additions and 35 deletions.
4 changes: 2 additions & 2 deletions bitblas/base/common_schedules.py
@@ -22,7 +22,7 @@
from typing import Callable, List

from tvm import tir

from bitblas.utils import retrieve_func_from_module
from .analysis import BlockInfo


@@ -74,7 +74,7 @@ def get_output_blocks(
"""

# collect arguments buffer
func = sch.mod["main"]
func = retrieve_func_from_module(sch.mod)
args = list(func.buffer_map.values())

output_blocks = []
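The `sch.mod["main"]` lookup above is replaced by `retrieve_func_from_module`, so a scheduled module whose single PrimFunc is not named "main" still works. The helper's implementation is not part of this diff; a minimal sketch of what such a utility could look like, assuming the module holds exactly one PrimFunc, is:

    from tvm import IRModule, tir

    def retrieve_func_from_module(ir_module: IRModule) -> tir.PrimFunc:
        """Return the single PrimFunc of a module, whatever its global symbol is."""
        if not isinstance(ir_module, IRModule):
            raise ValueError(f"Expected an IRModule, got {type(ir_module)}")
        global_vars = ir_module.get_global_vars()
        if len(global_vars) != 1:
            raise ValueError("The module should contain exactly one PrimFunc")
        return ir_module[global_vars[0]]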
21 changes: 18 additions & 3 deletions bitblas/base/utils.py
@@ -21,7 +21,12 @@
import tempfile
import itertools
from tvm.ir.supply import GlobalVarSupply
from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
from bitblas.utils import (
tensor_replace_dp4a,
tensor_remove_make_int4,
tensor_remove_make_int2,
retrieve_func_from_module,
)
from bitblas.utils.tensor_adapter import (
np_float2np_bf16,)
import logging
@@ -58,18 +63,28 @@ def __init__(self, config, sch, mod: Module):
self.time_evaluator = None

def profile(self, data_distribution="uniform"):
func = self.sch.mod["main"]
func = retrieve_func_from_module(self.sch.mod)
device = self.config.arch.device
profile_tensors = get_dummy_input_arrays(func, device, distribution=data_distribution)
latency = self.time_evaluator(*profile_tensors).mean * 1e3
return latency


def get_roller_hints_from_func(func: tir.PrimFunc,
def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
arch: TileDevice,
topk: int = 10,
tensorcore_only: bool = False,
allow_gemv: bool = False) -> Optional[List[Hint]]:
func = None
if isinstance(func_or_module, tir.PrimFunc):
func = func_or_module
elif isinstance(func_or_module, IRModule):
func = retrieve_func_from_module(func_or_module)
else:
raise ValueError("Not supported type: ", type(func_or_module))

assert func is not None, "The function should not be None"

if tensorcore_only:
try:
tensorized_func, tags = get_tensorized_func_and_tags(
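With this change `get_roller_hints_from_func` accepts either a `tir.PrimFunc` or a whole `IRModule`, so callers no longer need to index the module by a hard-coded "main" name. A hedged usage sketch follows; the shapes, dtypes, and the `CUDA` arch constructor are illustrative assumptions, and the import paths mirror those used elsewhere in this diff (adjust if your checkout differs):

    from bitblas.base.arch import CUDA  # assumed TileDevice implementation
    from bitblas.base.utils import get_roller_hints_from_func
    from bitblas.utils import retrieve_func_from_module
    from bitblas.ops.general_matmul.tirscript import matmul_select_implementation

    arch = CUDA("cuda")  # assumption: an NVIDIA TileDevice target
    ir_module = matmul_select_implementation(
        M=16384, N=16384, K=16384, in_dtype="float16",
        out_dtype="float16", accum_dtype="float16", layout="nt")

    # Both call forms should now be accepted:
    hints = get_roller_hints_from_func(ir_module, arch, topk=10,
                                       tensorcore_only=True, allow_gemv=True)
    hints = get_roller_hints_from_func(retrieve_func_from_module(ir_module),
                                       arch, topk=10, tensorcore_only=True,
                                       allow_gemv=True)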
27 changes: 15 additions & 12 deletions bitblas/gpu/matmul_analysis.py
@@ -431,9 +431,8 @@ def is_dequantize(block: BlockRV) -> bool:
has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads)
if not has_uint_input:
return False
if len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype):
return False
return True
return not (len(block_stmt.writes) != 1 or
"float" not in str(block_stmt.writes[0].buffer.dtype))

dequantize_blocks = [block for block in blocks if is_dequantize(block)]
return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None
@@ -552,9 +551,7 @@ def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool:
len(
collect_block_iter_vars_used_in_access_region(block_stmt,
block_stmt.writes[0].region)) > 0)
if not all(conditions):
return False
return True
return all(conditions)

# step2. transform function to tensorcore matmul (e.g. conv2d with im2col)
def check_sm_version(arch: str) -> int:
@@ -677,14 +674,20 @@ def check_last_trait(region: List[Range]):
block_stmt = sch.get(main_block)

# 16 for 16 bits tensor core while 32 for 8bits tensorcore.
minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
minimal_tensorize_spatial_threshold = 16
minimal_tensorize_reduce_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
# the batch dimension is not taken into consideration.
extent = block_stmt.iter_vars[1].dom.extent
if isinstance(extent, tir.expr.IntImm) and (extent.value < (1 if allow_gemv else
minimal_tensorize_threshold)):
return func, None
for item_var in block_stmt.iter_vars[2:]:
for item_var in block_stmt.iter_vars[1:]:
extent = item_var.dom.extent
iter_type = item_var.iter_type

if iter_type is IterVar.DataPar:
minimal_tensorize_threshold = minimal_tensorize_spatial_threshold
elif iter_type is IterVar.CommReduce:
minimal_tensorize_threshold = minimal_tensorize_reduce_threshold
else:
raise ValueError(f"Unknown IterVar type {iter_type}")

if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold):
return func, None
tags = analysis_tensorcore_tags(sch, main_block, target)
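The single `minimal_tensorize_threshold` is split into a spatial threshold (16) and a reduction threshold (16 for 16-bit inputs, 32 for 8-bit inputs), and every non-batch iterator is now checked rather than only `iter_vars[2:]`. A self-contained sketch of the same per-axis check, with hypothetical axis extents:

    # Minimal sketch of the per-axis tensorize-threshold check (values hypothetical).
    def meets_tensorize_thresholds(axes, in_dtype: str) -> bool:
        spatial_threshold = 16
        reduce_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
        for kind, extent in axes:  # batch axis already excluded
            threshold = spatial_threshold if kind == "spatial" else reduce_threshold
            if extent < threshold:
                return False
        return True

    # An int8 GEMM with M=16, N=16384, K=16 is rejected because K < 32 ...
    assert not meets_tensorize_thresholds(
        [("spatial", 16), ("spatial", 16384), ("reduce", 16)], "int8")
    # ... while the same shape in float16 passes.
    assert meets_tensorize_thresholds(
        [("spatial", 16), ("spatial", 16384), ("reduce", 16)], "float16")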
@@ -141,7 +141,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
@@ -340,7 +340,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
51 changes: 49 additions & 2 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore_s4.py
@@ -32,11 +32,17 @@ class MatmulINT4FineGrainScheduler(MatmulFineGrainScheduler):

def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}"
M = self.M
K = self.K // 2 # 2xint4 should be packed into one single int8
# Simple TIR Compute Expression
storage_dtype = "int8"

# This is a hack to utilize tensor core
if isinstance(M, int) and M < 16:
M = 16

ir_module = matmul_select_implementation(
M=self.M,
M=M,
N=self.N,
K=K,
in_dtype=storage_dtype,
@@ -46,7 +52,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
@@ -230,6 +236,47 @@ def __post_init__(self):
@dataclass
class MatmulINT4WeightPropagationScheduler(MatmulWeightPropagationScheduler):

def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}"
M = self.M
K = self.K // 2 # 2xint4 should be packed into one single int8
# Simple TIR Compute Expression
storage_dtype = "int8"

# This is a hack to utilize tensor core
if isinstance(M, int) and M < 16:
M = 16

ir_module = matmul_select_implementation(
M=M,
N=self.N,
K=K,
in_dtype=storage_dtype,
out_dtype=self.out_dtype,
accum_dtype=self.accum_dtype,
layout=layout,
propagate_b=self.weight_transform_kind)

roller_hints = get_roller_hints_from_func(
ir_module,
arch,
topk,
tensorcore_only=True,
allow_gemv=True,
)

if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")

def serialze_hints_to_configs(hints: List[Hint]):
configs = []
for hint in hints:
config = self.TLHint.from_roller_hint(hint)
configs.append(config)
return configs

return serialze_hints_to_configs(roller_hints)

def apply_config(
self,
block_row_warps=2,
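Both INT4 schedulers above build the roller problem on adjusted shapes: K is halved because two int4 values are packed into one int8 element, and a small static M is padded up to 16 so the roller sees a tensor-core-sized tile (the "hack to utilize tensor core"). A sketch of just that shape adjustment, with a hypothetical helper name (in the diff this logic is inlined in get_roller_configs):

    def int4_roller_problem_shape(M, N, K):
        K = K // 2                      # 2x int4 packed into one int8
        if isinstance(M, int) and M < 16:
            M = 16                      # pad tiny M up to one tensor-core tile
        return M, N, K

    assert int4_roller_problem_shape(1, 4096, 4096) == (16, 4096, 2048)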
@@ -136,7 +136,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
@@ -154,7 +154,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
zeros_mode=self.zeros_mode)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
@@ -32,9 +32,15 @@ class MatmulINT4DequantizeFineGrainedScheduler(MatmulDequantizeFineGrainedSchedu

def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}"
M = self.M
K = self.K // 2 # 2xint4 should be packed into one single int8
storage_dtype = "int8"
num_bits = self.num_bits * 2

# This is a hack to utilize tensor core
if isinstance(M, int) and M < 16:
M = 16

# INT4XINT2 is equal to int8xint4 with reduced shape
# Simple TIR Compute Expression
ir_module = matmul_dequantize_select_implementation(
@@ -56,7 +62,7 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
zeros_mode=self.zeros_mode)

roller_hints = get_roller_hints_from_func(
ir_module["main"],
ir_module,
arch,
topk,
tensorcore_only=True,
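For the dequantized INT4 path, an INT4xINT2 problem is remapped to an INT8xINT4 problem with half the K extent: the activation storage dtype becomes int8, the weight bit width is doubled (`num_bits * 2`), and K is halved, so the total number of weight bits along K is preserved. A small worked example:

    # INT4 x INT2 expressed as INT8 x INT4 over a reduced K (worked example).
    num_bits, K = 2, 4096                        # original INT2 weight bits, K extent
    remapped_bits, remapped_K = num_bits * 2, K // 2
    assert num_bits * K == remapped_bits * remapped_K   # 8192 bits either way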
@@ -3,18 +3,23 @@
from bitblas import tvm as tvm
from tvm import DataType
import tvm.tl.language as T
from typing import Optional
from typing import Optional, List
from bitblas.tl.utils import (
get_mma_micro_size, # noqa: F401
make_swizzle_layout, # noqa: F401
index_to_coordinates, # noqa: F401
)
from bitblas.base.arch import TileDevice
from bitblas.base.roller.hint import Hint
from bitblas.tl.macro_generator import (
INT4TensorCoreIntrinEmitterWithLadderTransform, # noqa: F401
)
from bitblas.ops.common import TransformKind # noqa: F401
from dataclasses import dataclass
from bitblas.base.utils import get_roller_hints_from_func
from bitblas.gpu.intrin.lop3 import get_lop3_intrin_group
from bitblas.ops.general_matmul.tirscript import (
matmul_dequantize_select_implementation,)
from bitblas.ops.general_matmul.tilelang.dequantize.ladder_weight_transform_tensorcore import (
MatmulDequantizeWeightPropagationScheduler,)

@@ -25,6 +30,60 @@
@dataclass
class MatmulINT4DequantizeWeightPropagationScheduler(MatmulDequantizeWeightPropagationScheduler):

def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
layout = f"{'t' if self.trans_A else 'n'}{'t' if self.trans_B else 'n'}"
M = self.M
K = self.K // 2 # 2xint4 should be packed into one single int8
storage_dtype = "int8"
num_bits = self.num_bits * 2

# This is a hack to utilize tensor core
if isinstance(M, int) and M < 16:
M = 16

# INT4XINT2 is equal to int8xint4 with reduced shape
# Simple TIR Compute Expression
ir_module = matmul_dequantize_select_implementation(
M=M,
N=self.N,
K=K,
in_dtype=storage_dtype,
out_dtype=self.out_dtype,
accum_dtype=self.accum_dtype,
layout=layout,
bit=num_bits,
storage_dtype=self.storage_dtype,
source_format=self.source_format,
with_scaling=self.with_scaling,
with_zeros=self.with_zeros,
group_size=self.group_size,
fast_decoding=self.fast_decoding,
with_bias=self.with_bias,
zeros_mode=self.zeros_mode)

roller_hints = get_roller_hints_from_func(
ir_module,
arch,
topk,
tensorcore_only=True,
allow_gemv=True,
)

if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")

for hint in roller_hints:
print(hint)

def serialze_hints_to_configs(hints: List[Hint]):
configs = []
for hint in hints:
config = self.TLHint.from_roller_hint(hint)
configs.append(config)
return configs

return serialze_hints_to_configs(roller_hints)

def apply_config(
self,
block_row_warps: Optional[int] = None,
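As in the fine-grained scheduler, the roller hints are converted into TileLang tuning configs through `self.TLHint.from_roller_hint`. The nested helper could equivalently be written as a comprehension; a hedged sketch assuming the surrounding scheduler class, with `List` and `Hint` imported as in the diff above:

    def serialize_hints_to_configs(self, hints: List[Hint]):
        # One TileLang tuning config per roller hint.
        return [self.TLHint.from_roller_hint(hint) for hint in hints]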
11 changes: 4 additions & 7 deletions bitblas/ops/operator.py
@@ -21,6 +21,7 @@
from bitblas.builder.wrapper import TIRWrapper, TLWrapper
from bitblas.builder.lib_generator import LibraryGenerator
from bitblas.common import MAX_ERROR_MESSAGE_LENGTH
from bitblas.utils import retrieve_func_from_module
from dataclasses import dataclass
import logging
import re
@@ -317,6 +318,7 @@ def apply_fast_tuning(
elif self.is_tilelang_backend():
# Finetune the schedule
tuning_configs = self.get_tl_tuning_config(topk=topk)
assert len(tuning_configs) > 0, "No tuning config found for this operator."
_, best = tl_apply_and_build(
func_or_scheduler, tuning_configs, arch=self.arch, parallel_build=parallel_build)
# Return the best Config as Hint
@@ -368,11 +370,9 @@ def hardware_aware_finetune(
assert (
len(scheduled_mod.get_global_vars()) == 1
), "The optimized module should only have one global variable for default schedule."
assert (
"main" in scheduled_mod
), "The optimized module should have a function named 'main' for default schedule."
default_kernal_name = self.kernel_name_generator.generate(best_hint)
func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name)
func = retrieve_func_from_module(scheduled_mod).with_attr("global_symbol",
default_kernal_name)
scheduled_ir_module = tvm.IRModule({default_kernal_name: func})
self._update_optimized_mod(scheduled_ir_module)

@@ -465,9 +465,6 @@ def forward(self, *args):
def __call__(self, *args: Any) -> Any:
return self.forward(*args)

def update_func(self, func: PrimFunc):
self.ir_module["main"] = func

def update_runtime_module(self, rt_mod=None, srcpath=None, libpath=None):
if rt_mod is not None:
self.rt_mod = rt_mod