[TIR] Expose Memory Copy-Related PTX Builtins #12611

Merged 2 commits on Aug 26, 2022
Changes from 1 commit
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
@@ -59,6 +59,7 @@
tvm_bmma_sync,
tvm_fill_fragment,
)
from .op import ptx_ldmatrix, ptx_cp_async, ptx_commit_group, ptx_wait_group
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
111 changes: 111 additions & 0 deletions python/tvm/tir/op.py
@@ -831,6 +831,117 @@ def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
)


def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
"""TVM intrinsic for ptx load matrix from shared memory
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix

Parameters
----------
dtype : str
The data type of the result.

trans : bool
Whether the matrix is loaded in column-major (transposed) format.

num : IntImm
The number of matrices.

type : Literal[".b16"]
The data type of the matrices.

local_ptr : Var
The local pointer variable.

local_offset : Expr
The offset of the local pointer.

smem_ptr : Var
The shared memory pointer variable.

smem_offset : Expr
The offset of the shared memory pointer.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
dtype,
"tir.ptx_ldmatrix",
trans,
num,
type,
local_ptr,
local_offset,
smem_ptr,
smem_offset,
)
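For illustration, a minimal sketch of how this builtin could be invoked when constructing TIR by hand, mirroring the unit test added below; the buffer names, shapes, and offsets are examples only, not part of the patch:

```python
from tvm import tir

# Illustrative buffers: a 16x16 fp16 tile in shared memory and a
# per-thread 8-element fragment in local registers.
buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
buffer_local = tir.decl_buffer([8], "float16", scope="local")

# Load four .b16 matrices in row-major order (trans=False) from shared memory
# into the local fragment; the result is a tir.Call PrimExpr.
load = tir.ptx_ldmatrix(
    "float16", False, 4, ".b16", buffer_local.data, 0, buffer_shared.data, 0
)
assert load.op.name == "tir.ptx_ldmatrix"
```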


def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
"""TVM intrinsic for ptx async copy from global to shared memory
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async

Parameters
----------
dtype : str
The data type of the result.

shared_ptr : Var
The shared memory pointer variable.

shared_offset : Expr
The offset of the shared memory pointer.

global_ptr : Var
The global memory pointer variable.

global_offset : Expr
The offset of the global memory pointer.

bytes : int
The size of the data to copy, in bytes.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
dtype, "tir.ptx_cp_async", shared_ptr, shared_offset, global_ptr, global_offset, bytes
)
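As a hedged usage sketch, the call below copies 16 bytes (eight float16 elements) from a global buffer into a shared buffer; the buffers and offsets are illustrative only:

```python
from tvm import tir

# Illustrative source/destination buffers for a 16-byte asynchronous copy.
buffer_global = tir.decl_buffer([8], "float16", scope="global")
buffer_shared = tir.decl_buffer([8], "float16", scope="shared")

# 8 float16 elements = 16 bytes copied from global memory into shared memory.
copy = tir.ptx_cp_async("float16", buffer_shared.data, 0, buffer_global.data, 0, 16)
assert copy.op.name == "tir.ptx_cp_async"
```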


def ptx_commit_group():
"""TVM intrinsic for ptx async copy commit
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_commit_group")


def ptx_wait_group(num):
"""TVM intrinsic for ptx async copy wait
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group

Parameters
----------
num : int
The number of the most recent uncommitted pending cp.async groups to wait for.

Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_wait_group", num)
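The two intrinsics above are meant to bracket one or more ptx_cp_async calls. Below is a sketch of the intended ordering, building the expressions only rather than a complete scheduled TIR function; the buffers are illustrative assumptions:

```python
from tvm import tir

buffer_global = tir.decl_buffer([8], "float16", scope="global")
buffer_shared = tir.decl_buffer([8], "float16", scope="shared")

# 1. Enqueue an asynchronous 16-byte copy from global to shared memory.
copy = tir.ptx_cp_async("float16", buffer_shared.data, 0, buffer_global.data, 0, 16)
# 2. Commit all previously issued cp.async operations as one group.
commit = tir.ptx_commit_group()
# 3. Wait until at most 0 committed groups remain pending, i.e. all copies
#    have finished (per the PTX cp.async.wait_group semantics).
wait = tir.ptx_wait_group(0)
```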


def vectorlow(dtype, vec):
"""Get the low level half of the vector

40 changes: 35 additions & 5 deletions tests/python/unittest/test_tir_op_types.py
@@ -142,6 +142,32 @@ def test_tir_op_tvm_fill_fragment():
assert expr.op.name == "tir.tvm_fill_fragment"


def test_op_ptx_ldmatrix():
buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
buffer_local = tir.decl_buffer([8], "float16", scope="local")
expr = tir.ptx_ldmatrix(
"float16", False, 4, ".b16", buffer_local.data, 0, buffer_shared.data, 0
)
assert expr.op.name == "tir.ptx_ldmatrix"


def test_op_ptx_cp_async():
buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
buffer_local = tir.decl_buffer([8], "float16", scope="local")
expr = tir.ptx_cp_async("float16", buffer_shared.data, 0, buffer_local.data, 0, 16)
assert expr.op.name == "tir.ptx_cp_async"


def test_op_ptx_commit_group():
expr = tir.ptx_commit_group()
assert expr.op.name == "tir.ptx_commit_group"


def test_op_ptx_wait_group():
expr = tir.ptx_wait_group(8)
assert expr.op.name == "tir.ptx_wait_group"


def test_tir_op_vectorlow():
buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
vec = buffer.vload([0, 0], dtype="int8x16")
@@ -203,11 +229,15 @@ def test_tir_op_TVMBackendFreeWorkspace():
test_tir_op_type_annotation()
test_tir_op_tvm_access_ptr()
test_tir_op_tvm_throw_last_error()
-test_tir_op_tvm_load_matrix_sync(),
-test_tir_op_tvm_store_matrix_sync(),
-test_tir_op_tvm_mma_sync(),
-test_tir_op_tvm_bmma_sync(),
-test_tir_op_tvm_fill_fragment(),
+test_tir_op_tvm_load_matrix_sync()
+test_tir_op_tvm_store_matrix_sync()
+test_tir_op_tvm_mma_sync()
+test_tir_op_tvm_bmma_sync()
+test_tir_op_tvm_fill_fragment()
+test_op_ptx_ldmatrix()
+test_op_ptx_cp_async()
+test_op_ptx_commit_group()
+test_op_ptx_wait_group()
test_tir_op_vectorlow()
test_tir_op_vectorhigh()
test_tir_op_vectorcombine()