
[TIR] Support tensorization using ldmatrix + MMA #11355

Merged · 4 commits · May 20, 2022
27 changes: 27 additions & 0 deletions include/tvm/tir/builtin.h
@@ -651,6 +651,33 @@ TVM_DLL const Op& ptx_cp_async();
TVM_DLL const Op& ptx_commit_group();
TVM_DLL const Op& ptx_wait_group();

/*!
* \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer.
* For example, if each thread in a warp of size 32 has 4 elements from the result of
* m16xn8xk16 MMA in its registers, this intrinsic can be used to store the result in a
* 16x8 region in shared or global memory.
*
* There is no real PTX instruction that does that, but we want to hide details of
* complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g.
* LowerWarpMemory).
*
* void mma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var dst_stride);
*/
TVM_DLL const Op& mma_store();
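The "complex index manipulation" this intrinsic hides can be illustrated with a standalone sketch. The thread-to-element mapping below is the standard PTX m16n8 accumulator fragment layout (4 elements per thread across a 32-thread warp); the function names are hypothetical, for illustration only, and are not part of TVM's API:

```python
# Sketch of the index mapping that mma_store hides, assuming the standard
# PTX m16n8 accumulator fragment layout. Helper names are illustrative only.

def mma_store_coords(tid, local_id):
    """Map (thread id, local register index) -> (row, col) in the 16x8 result."""
    group = tid // 4                    # 8 groups of 4 threads
    lane = tid % 4                      # thread's position within its group
    row = group + 8 * (local_id // 2)   # local_id 0,1 -> rows 0..7; 2,3 -> rows 8..15
    col = lane * 2 + (local_id % 2)     # each thread covers two adjacent columns
    return row, col

def mma_store_sim(dst, src_regs, dst_stride):
    """Store each thread's 4 accumulator values into a row-major 16x8 region."""
    for tid in range(32):
        for local_id in range(4):
            row, col = mma_store_coords(tid, local_id)
            dst[row * dst_stride + col] = src_regs[tid][local_id]
```

Enumerating all `(tid, local_id)` pairs covers each of the 128 cells of the 16x8 region exactly once, which is why a single intrinsic call per warp suffices.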

/*!
 * \brief tvm intrinsic for zero-initializing MMA accumulation registers.
 * For example, if each thread in a warp of size 32 holds 4 elements of the
 * m16xn8xk16 MMA accumulator in its registers, this intrinsic can be used to
 * zero-initialize those 4 accumulation registers.
*
* There is no real PTX instruction that does that, but we introduce this intrinsic for the
* same reason as mma_store above.
*
* void mma_fill(IntImm local_size, Var local_ptr, Expr offset);
*/
TVM_DLL const Op& mma_fill();
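Functionally, mma_fill amounts to each thread zeroing a slice of its local accumulator storage. A minimal model of its semantics (a plain Python sketch with hypothetical names, not the TIR that TVM generates):

```python
# Sketch of mma_fill semantics: zero-initialize local_size accumulator slots
# starting at offset, per thread. The register file is modeled as a list;
# names are illustrative only.

def mma_fill_sim(local_regs, local_size, offset):
    """Zero out local_size accumulator entries at offset; return the buffer."""
    for i in range(local_size):
        local_regs[offset + i] = 0
    return local_regs
```

Keeping this behind an intrinsic lets passes such as LowerWarpMemory treat the fill as opaque until the final lowering stage, mirroring mma_store above.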

// TODO(tvm-team) replace the usage of the vector operations by Shuffle.
/*!
* \brief Get the high level half of the vector
1 change: 1 addition & 0 deletions python/tvm/tir/tensor_intrin/__init__.py
@@ -20,3 +20,4 @@
from .arm_cpu import *
from .dot_product_common import *
from .rocm import *
from .cuda import *