[TL] Inject Storage Sync Scope Automatically for TL (#177)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix for matrix support

* lint fix

* dispatch tensor core based on shapes

* update install commands

* import scripts

* remove shared mem hack

* revert change for swizzling

* bug fix

* tl examples

* Enhance Swizzle

* lint fix

* test fix

* lint fix

* optimize layout

* update tl utils.

* macro optimization

* test fix

* gemm_ss

* doc fix

* lint fix

* lint fix

* remove debug print

* remove debug print
LeiWang1999 authored Sep 6, 2024
1 parent b9fab25 commit 11649f0
Showing 4 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
22 changes: 11 additions & 11 deletions bitblas/tl/macro_generator.py
@@ -97,7 +97,7 @@ def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(inst.warp_rows, inst.warp_cols):
T.ptx_mma(
inst.accum_dtype,
"m16n8k16",
inst.mma_prefix,
"row",
"col",
inst.a_dtype_abbrv,
@@ -114,7 +114,7 @@ def MMA(inst, A_local_buf, B_local_buf, C_local_buf):

T.ptx_mma(
inst.accum_dtype,
"m16n8k16",
inst.mma_prefix,
"row",
"col",
inst.a_dtype_abbrv,
@@ -142,19 +142,18 @@ def LDMATRIX_A(
stride = inst.chunk
tx = thread_bindings % inst.WARP_SIZE
ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
- # self.ty = (thread_bindings // warp_size) % block_row_warps
- # self.tz = thread_bindings // (warp_size * block_row_warps)

for i in T.serial(inst.warp_rows):
T.ptx_ldmatrix(
"float16",
inst.a_dtype,
T.bool(False),
4,
".b16",
A_local_buf.data,
i * inst.local_size_a,
T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x,
ki * inst.micro_size_k,]),
get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, False),
get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, inst.a_transposed),
)

@staticmethod
@@ -171,15 +170,15 @@ def LDMATRIX_B(
tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
for j in T.serial(inst.warp_cols):
T.ptx_ldmatrix(
"float16",
inst.b_dtype,
T.bool(False), # TODO(lei): should be optimized
4,
".b16",
B_local_buf.data,
j * inst.local_size_b,
T.address_of(B_shared_buf[tz * inst.warp_col_tiles + j * inst.micro_size_y,
ki * inst.micro_size_k,]),
get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, True),
get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, inst.b_transposed),
)

# STS
@@ -203,13 +202,14 @@ def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
@staticmethod
@T.macro
def GEMM_SS(inst, A_shared_buf, B_shared_buf, C_local_buf, thread_bindings):
- A_local_buf = T.alloc_fragment((inst.warp_rows * inst.local_size),
+ # TODO(lei): alloc_buffer within the macro is not supported yet.
+ A_local_buf = T.alloc_fragment((inst.warp_rows * inst.local_size_a),
inst.a_dtype,
scope="local")
- B_local_buf = T.alloc_fragment((inst.warp_cols * inst.local_size),
+ B_local_buf = T.alloc_fragment((inst.warp_cols * inst.local_size_b),
inst.b_dtype,
scope="local")
- for ki in T.serial(0, (inst.block_K // inst.micro_size_k)):
+ for ki in T.serial(0, (inst.chunk // inst.micro_size_k)):
inst.LDMATRIX_A(
inst,
A_local_buf,
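In short, the macro_generator.py changes replace the hardcoded "m16n8k16" MMA shape and "float16" ldmatrix dtype with fields read from the emitter instance (inst.mma_prefix, inst.a_dtype / inst.b_dtype, inst.a_transposed / inst.b_transposed), so one set of macros can serve multiple dtypes and layouts. Below is a minimal, hypothetical sketch of how such fields might be derived from the operand dtype; the class name and the derivation rule are assumptions for illustration, not the BitBLAS implementation.

```python
# Hypothetical sketch only: derive the fields that the macros above read from
# `inst`, instead of hardcoding "m16n8k16" / "float16".
M_DIM, N_DIM = 16, 8  # Tensor Core mma.sync tile extents in M and N


def infer_mma_prefix(a_dtype: str) -> str:
    # k grows as the operand narrows: k=16 for 16-bit inputs, k=32 for 8-bit.
    bits = int("".join(ch for ch in a_dtype if ch.isdigit()))
    return f"m{M_DIM}n{N_DIM}k{256 // bits}"


class MMAEmitterConfig:
    """Assumed stand-in for the `inst` object consumed by the macros."""

    def __init__(self, a_dtype="float16", b_dtype="float16",
                 a_transposed=False, b_transposed=True):
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.a_transposed = a_transposed
        self.b_transposed = b_transposed
        self.mma_prefix = infer_mma_prefix(a_dtype)


assert MMAEmitterConfig().mma_prefix == "m16n8k16"
assert MMAEmitterConfig(a_dtype="int8", b_dtype="int8").mma_prefix == "m16n8k32"
```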
2 changes: 1 addition & 1 deletion bitblas/tl/utils.py
@@ -19,7 +19,7 @@ def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
# permutation on 4 banks, each bank has 32 bits
bank_elems = BANK_SIZE_BYTES // dtype.bits
new_col_idx_outer = None
print(f"coalescent_bits: {coalescent_bits}")

if coalescent_bits % 1024 == 0:
# Use 8 * 8 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
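For context, get_swizzle_layout permutes shared-memory column indices so that the 16-byte vectors loaded by ldmatrix land in different bank groups. A toy sketch of the 8 x 8 xor permutation the comments above describe (illustrative only; the real layout also depends on dtype width and row size):

```python
def swizzle_8x8(row_idx: int, col_idx: int) -> int:
    # Each column index stands for 8 consecutive fp16 values (one 128-bit read).
    # Xor-ing with the row spreads the 8 rows of a tile across 8 bank groups.
    return col_idx ^ (row_idx % 8)


# Without swizzling, rows 0..7 reading column group 0 would hit the same bank
# group; with it they hit groups 0..7.
print([swizzle_8x8(r, 0) for r in range(8)])  # -> [0, 1, 2, 3, 4, 5, 6, 7]
```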
1 change: 1 addition & 0 deletions docs/ExtendOperatorsWithDSL.md
@@ -1,5 +1,6 @@
### Using BitBLAS from DSL
```python
+ from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.base.arch import CUDA
from bitblas.base.utils import apply_and_build
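The doc change only adds the missing import to the DSL example. As a rough sketch of how these pieces are typically wired together (argument names follow the public BitBLAS examples and may differ across versions; the matmul PrimFunc below is just a placeholder workload):

```python
import tvm
from tvm.script import tir as T
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.base.arch import CUDA
from bitblas.base.utils import apply_and_build


@T.prim_func
def matmul(A: T.Buffer((1024, 1024), "float16"),
           B: T.Buffer((1024, 1024), "float16"),
           C: T.Buffer((1024, 1024), "float16")):
    # Placeholder NT-layout GEMM used only to drive the tuning flow.
    for i, j, k in T.grid(1024, 1024, 1024):
        with T.block("B"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float16(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


target = tvm.target.Target("nvidia/nvidia-a100")
arch = CUDA(target)
# Prefer the tensor-core policy when the workload can be tensorized.
tensorized_func, tags = get_tensorized_func_and_tags(matmul, arch.target)
policy = (TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)
          if tags else DefaultPolicy(func=matmul, arch=arch))
configs = policy.emit_config(20)
# `best` holds the fastest compiled candidate among the emitted configs.
cpresults, best = apply_and_build(matmul, configs, arch, parallel_build=True)
```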
