Skip to content

Commit

Permalink
mma intrin generation with metaprogramming
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 17, 2022
1 parent 5afb5f0 commit bf23fc5
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 489 deletions.
163 changes: 162 additions & 1 deletion python/tvm/tir/tensor_intrin/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Intrinsics for tensorization on NVIDIA GPU."""
from .. import Cast
from ..._ffi import register_func
from ...runtime import convert
from .. import TensorIntrin
Expand Down Expand Up @@ -46,6 +47,7 @@ def index_map_shared_16x16_to_ldmatrix_32x8_layout(i, j):
lift = convert

M_DIM = 16
N_DIM = 16
WARP_SIZE = 32
HALF_WARP = WARP_SIZE // 2
HALF_WARP_expr = lift(HALF_WARP)
Expand Down Expand Up @@ -81,7 +83,6 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed):
assert dtype == "int8"

if ldmatrix_col_major:
print("foo")
index_map = shared_32x16_to_ldmatrix_32x16_layout
shared_offset = (
lambda _, stride: stride
Expand Down Expand Up @@ -172,6 +173,148 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
return ldmatrix_desc, ldmatrix_impl


def get_mma_intrin(k_dim, out_dtype, transposed):
    """Generate a (desc, impl) pair of TIR prim funcs for a warp-level mma.sync intrinsic.

    Parameters
    ----------
    k_dim : int
        Reduction dimension of the mma shape: 16 (fp16 inputs) or 32 (int8 inputs).
    out_dtype : str
        Accumulator dtype: "float16", "float32" or "int32". The input dtype is
        derived from it (fp16 for float outputs, int8 for int32 output).
    transposed : bool
        Whether the B operand is stored transposed (indexed as (j, k) instead of (k, i)).

    Returns
    -------
    (mma_sync_desc, mma_sync_impl) : the computation description and the
        ptx_mma-based implementation, suitable for TensorIntrin.register.
    """
    # Per-thread fragment sizes: a warp of 32 threads collectively holds the
    # 16 x k_dim input tile and the 16x16 output tile.
    local_size = (M_DIM * k_dim) // WARP_SIZE
    local_size_out = (M_DIM * N_DIM) // WARP_SIZE

    index_map_C = shared_16x16_to_ldmatrix_32x8_layout

    if k_dim == 16:
        index_map_A = shared_16x16_to_ldmatrix_32x8_layout
        index_map_B = shared_16x16_to_ldmatrix_32x8_layout
        mma_prefix = "m16n8k16"
    elif k_dim == 32 and transposed:
        index_map_A = index_map_B = shared_16x32_to_ldmatrix_32x16_layout
        mma_prefix = "m16n8k32"
    elif k_dim == 32 and not transposed:
        index_map_A = shared_16x32_to_ldmatrix_32x16_layout
        index_map_B = shared_32x16_to_ldmatrix_32x16_layout
        mma_prefix = "m16n8k32"
    else:
        raise ValueError("Unsupported k_dim: {}".format(k_dim))

    out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": "int32"}[out_dtype]

    if out_dtype in ["float16", "float32"]:
        in_dtype = "float16"
        in_dtype_abbrv = "fp16"
    else:
        in_dtype = "int8"
        in_dtype_abbrv = "int8"

    def maybe_cast(v):
        # Widen the inputs when accumulating into a wider dtype (fp32/int32).
        if out_dtype in ["float32", "int32"]:
            return Cast(out_dtype, v)
        return v

    def maybe_swap(i, j):
        # For a transposed B, swap the (row, col) indices used to address it.
        if transposed:
            return j, i
        return i, j

    @T.prim_func
    def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(
            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
        )
        B = T.match_buffer(
            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
        )
        C = T.match_buffer(
            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
        )

        with T.block("root"):
            T.reads(
                C[0:WARP_SIZE, 0:local_size_out],
                A[0:WARP_SIZE, 0:local_size],
                B[0:WARP_SIZE, 0:local_size],
            )
            T.writes(C[0:WARP_SIZE, 0:local_size_out])

            # Describe the mma as a plain GEMM over logical (i, j, k) indices;
            # the index maps translate logical coordinates into the
            # (thread_id, local_id) warp-register layout produced by ldmatrix.
            for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
                with T.block("C"):
                    i, j, k = T.axis.remap("SSR", [i, j, k])
                    b_row_ind, b_col_ind = maybe_swap(k, j)

                    thread_id_C, local_id_C = index_map_C(i, j)
                    thread_id_A, local_id_A = index_map_A(i, k)
                    thread_id_B, local_id_B = index_map_B(b_row_ind, b_col_ind)

                    T.reads(
                        C[thread_id_C, local_id_C],
                        A[thread_id_A, local_id_A],
                        B[thread_id_B, local_id_B],
                    )
                    T.writes(C[thread_id_C, local_id_C])

                    C[thread_id_C, local_id_C] += maybe_cast(
                        A[thread_id_A, local_id_A]
                    ) * maybe_cast(B[thread_id_B, local_id_B])

    @T.prim_func
    def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(
            a, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
        )
        B = T.match_buffer(
            b, (WARP_SIZE, local_size), in_dtype, align=128, offset_factor=16, scope="warp"
        )
        C = T.match_buffer(
            c, (WARP_SIZE, local_size_out), out_dtype, align=128, offset_factor=16, scope="warp"
        )

        with T.block("root"):
            T.reads(
                C[0:WARP_SIZE, 0:local_size_out],
                A[0:WARP_SIZE, 0:local_size],
                B[0:WARP_SIZE, 0:local_size],
            )
            T.writes(C[0:WARP_SIZE, 0:local_size_out])
            tx = T.env_thread("threadIdx.x")
            T.launch_thread(tx, WARP_SIZE)

            # A single m16n8kX instruction covers only half of the 16x16
            # output tile, so two ptx_mma calls are issued: the first for the
            # lower halves of B and C ...
            T.evaluate(
                T.ptx_mma(
                    mma_prefix,
                    "row",
                    "col",
                    in_dtype_abbrv,
                    in_dtype_abbrv,
                    out_dtype_abbrv,
                    A.data,
                    A.elem_offset + tx * lift(local_size),
                    B.data,
                    B.elem_offset + tx * lift(local_size),
                    C.data,
                    C.elem_offset + tx * lift(local_size_out),
                    False,
                    dtype=out_dtype,
                )
            )

            # ... and the second for the upper halves, offset by half the
            # per-thread fragment size.
            T.evaluate(
                T.ptx_mma(
                    mma_prefix,
                    "row",
                    "col",
                    in_dtype_abbrv,
                    in_dtype_abbrv,
                    out_dtype_abbrv,
                    A.data,
                    A.elem_offset + tx * lift(local_size),
                    B.data,
                    B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
                    C.data,
                    C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
                    False,
                    dtype=out_dtype,
                )
            )

    return mma_sync_desc, mma_sync_impl


# ldmatrix intrin for the fp16 A operand of a 16x16 tile
# (args: k_dim=16, dtype="float16", is_b=False, transposed=False).
LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False))

Expand All @@ -191,3 +334,21 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:

# ldmatrix intrin for the int8 B operand, transposed layout
# (args: k_dim=32, dtype="int8", is_b=True, transposed=True).
LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans"
TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True))

# mma.sync intrinsics generated by get_mma_intrin(k_dim, out_dtype, transposed).
# Naming: mma_<in dtype><in dtype><out dtype>[_trans].

# fp16 x fp16 -> fp32, B in (K, N) layout.
MMA_f16f16f32_INTRIN = "mma_f16f16f32"
TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False))

# fp16 x fp16 -> fp32, B transposed (N, K).
MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans"
TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, "float32", True))

# fp16 x fp16 -> fp16 accumulation.
MMA_f16f16f16_INTRIN = "mma_f16f16f16"
TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False))

# fp16 x fp16 -> fp16, B transposed.
MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans"
TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, "float16", True))

# int8 x int8 -> int32 (k_dim=32).
MMA_i8i8i32_INTRIN = "mma_i8i8i32"
TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False))

# int8 x int8 -> int32, B transposed.
MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans"
TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True))
82 changes: 2 additions & 80 deletions tests/python/unittest/test_mma_16x8x16_4k_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,90 +7,13 @@
from tvm.tir.tensor_intrin.cuda import (
LDMATRIX_16x16_A_INTRIN,
LDMATRIX_16x16_B_INTRIN,
MMA_f16f16f32_INTRIN,
shared_16x16_to_ldmatrix_32x8_layout,
)
import tvm.testing
import numpy as np


# Description of a warp-level 16x16x16 fp16->fp32 mma as a plain GEMM over
# logical (i, j, k) indices, with each logical coordinate mapped into the
# (thread_id, local_id) warp-register layout used by ldmatrix.
# NOTE: only #-comments inside the body — TVMScript parses the function AST.
@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
    # Fragments live in the "warp" scope: 32 threads x 8 elements per thread.
    A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
        T.writes(C[0:32, 0:8])
        for i, j, k in T.grid(16, 16, 16):
            with T.block("C"):
                i, j, k = T.axis.remap("SSR", [i, j, k])
                # Translate logical GEMM coordinates into warp-fragment coordinates.
                thread_id_C, local_id_C = shared_16x16_to_ldmatrix_32x8_layout(i, j)
                thread_id_A, local_id_A = shared_16x16_to_ldmatrix_32x8_layout(i, k)
                thread_id_B, local_id_B = shared_16x16_to_ldmatrix_32x8_layout(k, j)

                T.reads(
                    C[thread_id_C, local_id_C],
                    A[thread_id_A, local_id_A],
                    B[thread_id_B, local_id_B],
                )
                T.writes(C[thread_id_C, local_id_C])
                # Multiply-accumulate with fp16 inputs widened to the fp32 accumulator.
                C[thread_id_C, local_id_C] += T.cast(
                    A[thread_id_A, local_id_A], "float32"
                ) * T.cast(B[thread_id_B, local_id_B], "float32")


# Implementation of the warp-level mma via two m16n8k16 ptx_mma instructions:
# each instruction covers half of the 16x16 output tile, so the second call
# offsets the B and C fragments by half the per-thread fragment size (4).
@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    B = T.match_buffer(b, (32, 8), "float16", align=128, offset_factor=16, scope="warp")
    C = T.match_buffer(c, (32, 8), "float32", align=128, offset_factor=16, scope="warp")

    with T.block("root"):
        T.reads(C[0:32, 0:8], A[0:32, 0:8], B[0:32, 0:8])
        T.writes(C[0:32, 0:8])
        tx = T.env_thread("threadIdx.x")
        T.launch_thread(tx, 32)

        # First m16n8k16: lower halves of B and C.
        T.evaluate(
            T.ptx_mma(
                "m16n8k16",
                "row",
                "col",
                "fp16",
                "fp16",
                "fp32",
                A.data,
                A.elem_offset + tx * 8,  # each thread owns 8 elements of A
                B.data,
                B.elem_offset + tx * 8,
                C.data,
                C.elem_offset + tx * 8,
                False,
                dtype="float32",
            )
        )

        # Second m16n8k16: upper halves, offset by half a fragment (4 elements).
        T.evaluate(
            T.ptx_mma(
                "m16n8k16",
                "row",
                "col",
                "fp16",
                "fp16",
                "fp32",
                A.data,
                A.elem_offset + tx * 8,
                B.data,
                B.elem_offset + tx * 8 + 4,
                C.data,
                C.elem_offset + tx * 8 + 4,
                False,
                dtype="float32",
            )
        )


@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
C_warp = T.match_buffer(a, [32, 8], dtype="float32", scope="warp")
Expand Down Expand Up @@ -160,7 +83,6 @@ def mma_fill_impl(a: T.handle) -> None:
T.evaluate(T.mma_fill(8, C_warp.data, C_warp.elem_offset, dtype="float32"))


# Register the hand-written tensor intrinsics used by the schedule in this test.
tir.TensorIntrin.register("mma_sync", mma_sync_desc, mma_sync_impl)
tir.TensorIntrin.register("mma_store", mma_store_desc, mma_store_impl)
tir.TensorIntrin.register("mma_fill", mma_fill_desc, mma_fill_impl)

Expand Down Expand Up @@ -291,7 +213,7 @@ def index_map(i, j):

sch.tensorize(loop_a, LDMATRIX_16x16_A_INTRIN)
sch.tensorize(loop_b, LDMATRIX_16x16_B_INTRIN)
sch.tensorize(sch.get_loops(block_inner)[-3], "mma_sync")
sch.tensorize(sch.get_loops(block_inner)[-3], MMA_f16f16f32_INTRIN)
sch.tensorize(sch.get_loops(block_init_c)[-2], "mma_fill")
sch.tensorize(sch.get_loops(C_warp)[-2], "mma_store")

Expand Down
Loading

0 comments on commit bf23fc5

Please sign in to comment.