Skip to content

Commit

Permalink
fix for int8 gemm (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 authored Sep 18, 2024
1 parent 916a54c commit ce7466c
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion bitblas/gpu/matmul_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,10 +561,15 @@ def check_sm_version(arch: str) -> int:
sm_version = arch.replace("sm_", "")
return int(sm_version) if sm_version.isdigit() else -1

def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool:
def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
target: Target) -> Union[bool, Dict]:
tags: Dict[str, Union[List[int], int]] = {}
block_stmt = sch.get(block)

# Nvidia Only Support Tensor Core for
# devices greater than 70.
if check_sm_version(target.arch) < 70:
return False
# analysis tensorcore axis
# todo(lei): maybe we can remove this in the future
(write_buffer_region,) = block_stmt.writes
Expand Down Expand Up @@ -612,6 +617,11 @@ def check_last_trait(region: List[Range]):
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
intrin_info["in_dtype"] = in_dtype
intrin_info["out_dtype"] = out_dtype

if 70 <= check_sm_version(target.arch) < 80 and out_dtype == "int32":
# INT32 Accum TensorCore only supports SM Version > 32.
return False

# if the last dimension is reduce axis, the B is transposed
intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region)
if func.attrs is not None and "input_transform_kind" in func.attrs:
Expand Down Expand Up @@ -666,6 +676,7 @@ def check_last_trait(region: List[Range]):

block_stmt = sch.get(main_block)

# 16 for 16 bits tensor core while 32 for 8bits tensorcore.
minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
# the batch dimension is not taken into consideration.
extent = block_stmt.iter_vars[1].dom.extent
Expand Down

0 comments on commit ce7466c

Please sign in to comment.