Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Expose WMMA-related TensorCore builtins #12589

Merged
merged 2 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@
from .op import tvm_tuple, tvm_struct_get, tvm_struct_set
from .op import address_of, lookup_param, assume, undef
from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error
from .op import (
tvm_load_matrix_sync,
tvm_store_matrix_sync,
tvm_mma_sync,
tvm_bmma_sync,
tvm_fill_fragment,
)
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
Expand Down
236 changes: 236 additions & 0 deletions python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,242 @@ def tvm_throw_last_error():
return call_intrin("handle", "tir.tvm_throw_last_error")


def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
    """TVM intrinsic for tensor core load operators

    Parameters
    ----------
    fragment : Var
        The wmma fragment.

    m : UIntImm
        The shape of wmma fragment.

    n : UIntImm
        The shape of wmma fragment.

    k : UIntImm
        The shape of wmma fragment.

    index : Expr
        The fragment index.

    buffer_ptr : Expr
        The fragment buffer pointer.

    stride : Expr
        The fragment stride.

    layout : StringImm
        The fragment layout.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    return call_intrin(
        "handle",
        "tir.tvm_load_matrix_sync",
        fragment,
        m,
        n,
        k,
        index,
        buffer_ptr,
        stride,
        layout,
    )


def tvm_mma_sync(
    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
):
    """TVM intrinsic for tensor core mma_sync operators

    Parameters
    ----------
    fragment_d : Var
        The wmma fragment_d.

    index_d : Expr
        The fragment_d index.

    fragment_a : Var
        The wmma fragment_a.

    index_a : Expr
        The fragment_a index.

    fragment_b : Var
        The wmma fragment_b.

    index_b : Expr
        The fragment_b index.

    fragment_c : Var
        The wmma fragment_c.

    index_c : Expr
        The fragment_c index.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Operands are interleaved as (fragment, index) pairs: d, a, b, c.
    operands = (
        fragment_d, index_d,
        fragment_a, index_a,
        fragment_b, index_b,
        fragment_c, index_c,
    )
    return call_intrin("handle", "tir.tvm_mma_sync", *operands)


def tvm_bmma_sync(
    fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c
):
    """TVM intrinsic for tensor core bmma_sync operators

    Parameters
    ----------
    fragment_d : Var
        The bwmma fragment_d.

    index_d : Expr
        The fragment_d index.

    fragment_a : Var
        The bwmma fragment_a.

    index_a : Expr
        The fragment_a index.

    fragment_b : Var
        The bwmma fragment_b.

    index_b : Expr
        The fragment_b index.

    fragment_c : Var
        The bwmma fragment_c.

    index_c : Expr
        The fragment_c index.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    # Operands are interleaved as (fragment, index) pairs: d, a, b, c.
    operands = (
        fragment_d, index_d,
        fragment_a, index_a,
        fragment_b, index_b,
        fragment_c, index_c,
    )
    return call_intrin("handle", "tir.tvm_bmma_sync", *operands)


def tvm_fill_fragment(fragment, m, n, k, index, value):
    """TVM intrinsic for tensor core fill_fragment operators

    Parameters
    ----------
    fragment : Var
        The wmma fragment

    m : UIntImm
        The shape of wmma fragment.

    n : UIntImm
        The shape of wmma fragment.

    k : UIntImm
        The shape of wmma fragment.

    index : Expr
        The fragment index.

    value : Expr
        The value to be filled in fragment.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    args = (fragment, m, n, k, index, value)
    return call_intrin("handle", "tir.tvm_fill_fragment", *args)


def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
    """TVM intrinsic for tensor core store operators

    Parameters
    ----------
    fragment : Var
        The wmma fragment.

    m : UIntImm
        The shape of wmma fragment.

    n : UIntImm
        The shape of wmma fragment.

    k : UIntImm
        The shape of wmma fragment.

    index : Expr
        The fragment index.

    buffer_ptr : Expr
        The fragment buffer pointer.

    stride : Expr
        The fragment stride.

    layout : StringImm
        The fragment layout.

    Returns
    -------
    call : PrimExpr
        The call expression.
    """
    args = (fragment, m, n, k, index, buffer_ptr, stride, layout)
    return call_intrin("handle", "tir.tvm_store_matrix_sync", *args)


def vectorlow(dtype, vec):
"""Get the low level half of the vector

Expand Down
43 changes: 43 additions & 0 deletions tests/python/unittest/test_tir_op_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,44 @@ def test_tir_op_tvm_throw_last_error():
assert expr.op.name == "tir.tvm_throw_last_error"


def test_tir_op_tvm_load_matrix_sync():
    # Build a 16x16x16 wmma load call and check the resulting op name.
    frag_buffer = tir.decl_buffer((16, 16), "float32")
    src_ptr = tir.Var("x", "handle")
    call = tir.tvm_load_matrix_sync(frag_buffer.data, 16, 16, 16, 0, src_ptr, 128, "row_major")
    assert call.op.name == "tir.tvm_load_matrix_sync"


def test_tir_op_tvm_store_matrix_sync():
    # Build a 16x16x16 wmma store call and check the resulting op name.
    frag_buffer = tir.decl_buffer((16, 16), "float32")
    dst_ptr = tir.Var("x", "handle")
    call = tir.tvm_store_matrix_sync(frag_buffer.data, 16, 16, 16, 0, dst_ptr, 128, "row_major")
    assert call.op.name == "tir.tvm_store_matrix_sync"


def test_tir_op_tvm_mma_sync():
    # Four fragments: d = a * b + c, each with index 0.
    frags = [tir.decl_buffer((16, 16), "float32") for _ in range(4)]
    call = tir.tvm_mma_sync(
        frags[0].data, 0, frags[1].data, 0, frags[2].data, 0, frags[3].data, 0
    )
    assert call.op.name == "tir.tvm_mma_sync"


def test_tir_op_tvm_bmma_sync():
    # Four fragments: d = a * b + c, each with index 0.
    frags = [tir.decl_buffer((16, 16), "float32") for _ in range(4)]
    call = tir.tvm_bmma_sync(
        frags[0].data, 0, frags[1].data, 0, frags[2].data, 0, frags[3].data, 0
    )
    assert call.op.name == "tir.tvm_bmma_sync"


def test_tir_op_tvm_fill_fragment():
    # Fill a 16x16x16 fragment (index 0) with the value 0.
    frag_buffer = tir.decl_buffer((16, 16), "float32")
    call = tir.tvm_fill_fragment(frag_buffer.data, 16, 16, 16, 0, 0)
    assert call.op.name == "tir.tvm_fill_fragment"


def test_tir_op_vectorlow():
buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
vec = buffer.vload([0, 0], dtype="int8x16")
Expand Down Expand Up @@ -165,6 +203,11 @@ def test_tir_op_TVMBackendFreeWorkspace():
test_tir_op_type_annotation()
test_tir_op_tvm_access_ptr()
test_tir_op_tvm_throw_last_error()
# Trailing commas removed: `f(),` is an expression statement that builds a
# throwaway 1-tuple, inconsistent with the surrounding plain call statements.
test_tir_op_tvm_load_matrix_sync()
test_tir_op_tvm_store_matrix_sync()
test_tir_op_tvm_mma_sync()
test_tir_op_tvm_bmma_sync()
test_tir_op_tvm_fill_fragment()
test_tir_op_vectorlow()
test_tir_op_vectorhigh()
test_tir_op_vectorcombine()
Expand Down