@krzysz00, @sjw36, @jerryyin This is a very first draft of this task, but I thought to get the ball rolling.
Goal
Adding WMMA support to rocMLIR
What is WMMA
WMMA is a new set of intrinsics available on RDNA3. A simpler explanation (and a hello-world HIP application) can be found here:
https://gpuopen.com/learn/wmma_on_rdna3/
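Before going through the flavours, it may help to pin down the math a single WMMA instruction performs: a 16x16x16 multiply-accumulate, D = A*B + C. The plain-Python sketch below (the function name is mine, purely illustrative) models only the tile-level semantics, not the per-lane register layout.

```python
# Reference semantics of one WMMA tile operation: D = A * B + C,
# where A and B are 16x16 input tiles and C/D are 16x16 accumulators.
# With f16 inputs and an f32 accumulator this corresponds to the
# wmma_f32_16x16x16_f16 flavour.

TILE = 16

def wmma_16x16x16(a, b, c):
    """Multiply two 16x16 tiles and add a 16x16 accumulator tile."""
    d = [[0.0] * TILE for _ in range(TILE)]
    for i in range(TILE):
        for j in range(TILE):
            acc = c[i][j]
            for k in range(TILE):
                acc += a[i][k] * b[k][j]
            d[i][j] = acc
    return d
```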
Different WMMA flavours
There are 12 versions of the intrinsics in total:
__builtin_amdgcn_wmma_f32_16x16x16_f16_{w32,w64}
__builtin_amdgcn_wmma_f32_16x16x16_bf16_{w32,w64}
__builtin_amdgcn_wmma_f16_16x16x16_f16_{w32,w64}
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_{w32,w64}
__builtin_amdgcn_wmma_i32_16x16x16_iu8_{w32,w64}
__builtin_amdgcn_wmma_i32_16x16x16_iu4_{w32,w64}
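The 12 names all follow the same scheme, `__builtin_amdgcn_wmma_<out>_16x16x16_<in>_<wave>`. The snippet below just enumerates the six (output, input) type pairs listed above for both wave sizes, to make the scheme explicit.

```python
# The six (output type, input type) pairs from the list above,
# expanded for both wave sizes to give the 12 intrinsic names.
PAIRS = [
    ("f32", "f16"),
    ("f32", "bf16"),
    ("f16", "f16"),
    ("bf16", "bf16"),
    ("i32", "iu8"),
    ("i32", "iu4"),
]

INTRINSICS = [
    f"__builtin_amdgcn_wmma_{out}_16x16x16_{inp}_{wave}"
    for out, inp in PAIRS
    for wave in ("w32", "w64")
]
```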
The output is always stored in 8 VGPRs for the w32 version and in 4 VGPRs for the w64 version. Because of that, the `fp16` and `bf16` versions have an `OPSEL` bit that decides whether to store the data in the high or the low half of each VGPR. This means that the return type of the different intrinsics in LLVM can be:

- `8xfloat`, `8xi32`, `16xf16`, `16xbf16` for the w32 version
- `4xfloat`, `4xi32`, `8xf16`, `8xbf16` for the w64 version

Add support in the backend (`AMDGPU.td` and `ROCDLOps.td`)

I came up with the following `AMDGPU` operation:

Notes:
- `f16->f16` (and `bf16->bf16`) wmma outputs a logical `8x2` or `16x2` vector; the translation of a 2D vector in LLVM results in an `llvm.array`, which is hard to flatten out. So the 2D -> 1D conversion logic needs to happen before we issue the wmma operation.
- `int8`: while `ui8` and `si8` are supported in MLIR, they are not supported in LLVM. This means that when the MLIR -> LLVM conversion happens, `unrealized_cast`s appear. So the best way to go is to let the user of `wmma` specify whether the `int8` input is signed or not.
- The LLVM intrinsics expect a `vector<4xi32>` instead of a `vector<16xi8>`. I implemented this conversion inside the `AMDGPUToROCDL` conversion.

The implementation is available at: #1035
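Both register-layout notes above come down to packing narrow elements into 32-bit VGPR lanes. The plain-Python sketch below (function names are mine, not from the patch) models the two cases: the `OPSEL`-controlled placement of a 16-bit result in the high or low half of a 32-bit register, and the `vector<16xi8>` -> `vector<4xi32>` repacking done for the `iu8` inputs.

```python
import struct

def pack_f16_opsel(halves, opsel_hi):
    """Place each 16-bit value in a 32-bit register, in the high half
    when opsel_hi is set, otherwise in the low half (the unused half
    is left as zero in this model)."""
    shift = 16 if opsel_hi else 0
    return [(h & 0xFFFF) << shift for h in halves]

def pack_i8_to_i32(bytes16):
    """Repack 16 i8 values into 4 i32 words, little-endian,
    as a bitcast of vector<16xi8> to vector<4xi32> would."""
    raw = struct.pack("<16b", *bytes16)
    return list(struct.unpack("<4i", raw))
```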
Add support in the front-end (`GridwiseGemmToBlockWise`, `BlockwiseGemmToThreadWise`, `ThreadWiseGemLowering`)

This depends on the previous task and can be done later.