[DISCUSS] Tensorization of Warp Level Primitives #371
def matmul():
    for i0, j0, k0 in grid(8, 8, 1):
        CC[i0, j0] += AA[i0, k0] * BB[j0, k0]
def func():
    for i, j, k in grid(128, 128, 128):
        C[i, j] += A[i, k] * B[j, k]
def func_step0():
    for i1, j1, k1 in grid(16, 16, 128):
        for i0, j0, k0 in grid(8, 8, 1):
            C[i1 * 8 + i0, j1 * 8 + j0] += A[i1 * 8 + i0, k1 + k0] * B[j1 * 8 + j0, k1 + k0]
def func_step1():
    for i1, j1, k1 in grid(16, 16, 128):
        for ia, ka in grid(8, 1):
            AA[ia, ka] = A[i1 * 8 + ia, k1 + ka]
        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]
        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]
        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]
def func_step2():
    for i1, j1, k1 in grid(16, 16, 128):
        for ia, ka in grid(8, 1):
            Awarp[ia % 2, ia // 2] = A[i1 * 8 + ia, k1 + ka]
        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]
        for ia, ka in grid(8, 1):
            AA[ia, ka] = Awarp[ia % 2, ia // 2]
        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]
        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]
def func_step3():
    for i1, j1, k1 in grid(16, 16, 128):
        for i in grid(2):
            for wid in thread_binding("warpIndex", 4):
                Awarp[i, wid] = A[i1 * 8 + wid * 2 + i, k1]
        for ib, kb in grid(8, 1):
            BB[ib, kb] = B[j1 * 8 + ib, k1 + kb]
        for ia, ka in grid(8, 1):
            AA[ia, ka] = Awarp[ia % 2, ia // 2]
        with tensorized:
            for i0, j0, k0 in grid(8, 8, 1):
                CC[i0, j0] += AA[i0, k0] * BB[j0, k0]
        for ic, jc in grid(8, 8):
            C[i1 * 8 + ic, j1 * 8 + jc] = CC[ic, jc]
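As a sanity check on the loop structure above, here is a small NumPy sketch (plain Python, not TVM script; names are illustrative): the loop split of func_step0 together with the 8x8x1 intrinsic, expressed as an outer-product update of one 8x8 tile of C, reproduces the untiled func, which computes C = A @ B.T under the C[i, j] += A[i, k] * B[j, k] convention.

import numpy as np

A = np.random.rand(128, 128).astype("float32")
B = np.random.rand(128, 128).astype("float32")

# untiled func: C[i, j] += A[i, k] * B[j, k], i.e. C = A @ B.T
C_ref = A @ B.T

# tiled form: the inner 8x8x1 block (the matmul intrinsic above) becomes
# an outer-product update of one 8x8 tile of C
C = np.zeros((128, 128), dtype="float32")
for i1 in range(16):
    for j1 in range(16):
        for k1 in range(128):
            AA = A[i1 * 8:(i1 + 1) * 8, k1]  # 8-element column tile of A
            BB = B[j1 * 8:(j1 + 1) * 8, k1]  # 8-element column tile of B
            C[i1 * 8:(i1 + 1) * 8, j1 * 8:(j1 + 1) * 8] += np.outer(AA, BB)

assert np.allclose(C, C_ref, atol=1e-3)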
I have thought about a new proposal for TensorCore. Would like to have some discussion :)

Main Idea: wmma load/store changes data layout.

Currently, we write the load/store intrinsic description like the following code:

with tir.block([16, 16], "store") as [vi, vj]:
    AA[vi, vj] = A[vi, vj]

However, the true behavior of load/store is the following (assume that we have a 16x16 warp op):

with tir.block([16, 16], "store") as [vi, vj]:
    AA[vi // 16, vj // 16, vi % 16, vj % 16] = A[vi, vj]

Hardware behavior
The warp fragment memory is somewhat contiguous (at least at the CUDA level). With the wmma API, we declare a warp memory as fragments.

Cache_read/write with re-layout support
To support this memory layout transformation during the schedule, we need to introduce a new primitive:

AA = s.cache_read(A, lambda i, j: (i // 16, j // 16, i % 16, j % 16))

And the generated IR is:

with tir.block([n, m]) as [i, j]:
    AA[i // 16, j // 16, i % 16, j % 16] = A[i, j]

Benefits
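To make the proposed re-layout concrete, here is a small pure-Python sketch (not a TVM API, just the index math) that applies the mapping from the cache_read lambda above to copy a 32x32 buffer into the 4-D fragment layout and checks that reads through the same mapping recover the original values:

import numpy as np

def relayout(i, j):
    # the proposed index map: 2-D logical index -> 4-D fragment layout
    return (i // 16, j // 16, i % 16, j % 16)

n, m = 32, 32
A = np.arange(n * m, dtype="float32").reshape(n, m)
AA = np.empty((n // 16, m // 16, 16, 16), dtype="float32")

# cache_read stage: AA[i // 16, j // 16, i % 16, j % 16] = A[i, j]
for i in range(n):
    for j in range(m):
        AA[relayout(i, j)] = A[i, j]

# later reads of the cached buffer use the same index map
assert all(AA[relayout(i, j)] == A[i, j] for i in range(n) for j in range(m))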
I have elaborated a bit on the workflow:

@tvm.script.tir
def intrin_desc(a: ty.handle, b: ty.handle, c: ty.handle):
    # the desc is like a vanilla matmul, with a special buffer scope
    A = tir.match_buffer(a, shape=(16, 16), scope='warp.layoutA')
    B = tir.match_buffer(b, shape=(16, 16), scope='warp.layoutB')
    C = tir.match_buffer(c, shape=(16, 16), scope='warp.layoutC')
    with block('root', [16, 16, tir.reduce_axis(16)]) as [vi, vj, vk]:
        tir.bind(vi, 0)
        tir.bind(vj, 0)
        tir.bind(vk, 0)
        for i, j, k in tir.grid(16, 16, 16):
            with block('C', [16, 16, tir.reduce_axis(16)]) as [vii, vji, vki]:
                tir.bind(vii, vi + i)
                tir.bind(vji, vj + j)
                tir.bind(vki, vk + k)
                C[vii, vji] += A[vii, vki] * B[vji, vki]
@tvm.script.tir
def intrin_impl(a: ty.handle, b: ty.handle, c: ty.handle):
    # calling the warp level intrinsic
    A = tir.match_buffer(a, shape=(16, 16), scope='warp.layoutA')
    B = tir.match_buffer(b, shape=(16, 16), scope='warp.layoutB')
    C = tir.match_buffer(c, shape=(16, 16), scope='warp.layoutC')
    with block('C', [16, 16, tir.reduce_axis(16)]) as [vii, vji, vki]:
        # fragment indices are computed based on elem_offset, such as A.elem_offset // 256
        tir.mma_16x16x16(A, B, C, A_frag_index, B_frag_index, C_frag_index)

def schedule_fn(sch):
    # split i, j, k and reorder ...
    sch.reorder(i0, j0, k0, i1, j1, k1)
    AA = sch.cache_read(A, 0, 'warp.layoutA')
    BB = sch.cache_read(B, 0, 'warp.layoutB')
    CC = sch.cache_write(C, 0, 'warp.layoutC')
    sch.compute_at(CC, k0)
    sch.compute_at(AA, k0)
    sch.compute_at(BB, k0)
    sch.tensorize(CC, i1, tensor_intrin)

The special layout can be lowered during buffer flatten.
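For illustration, a small Python sketch (hypothetical helper, not an actual TVM pass) of what lowering the special layout during buffer flatten could look like for a 16x16 'warp.layout' access: the fragment index is derived from the buffer's elem_offset (as in the A.elem_offset // 256 comment above), and the position inside the fragment comes from the 2-D index:

def flatten_warp_layout_access(i, j, elem_offset):
    # hypothetical lowering rule: one 16x16 fragment holds 256 elements,
    # so the fragment index is recovered from the element offset
    frag_index = elem_offset // 256
    # position of the element inside the fragment
    lane_index = (i % 16) * 16 + (j % 16)
    return frag_index, lane_index

# e.g. an access to the second fragment of a larger warp buffer (elem_offset = 256)
print(flatten_warp_layout_access(3, 5, 256))  # -> (1, 53)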
AMD's and Nvidia's MFMA (matrix multiplication) operators operate at the warp level. This creates some interesting challenges for tensorization, semantics checks, and the tensorization infrastructure. This is a discussion issue that tries to capture some of these questions.
Case Study Example: AMD's mfma32x32x1_f32 instruction
In AMD's case, the GPU has a warp size of 64, i.e. the operations are done collectively by 64 threads, and the inputs and outputs are distributed across the registers of each thread. To keep the presentation simple, we will use the following notations:
This will get lowered into the following code in the thread-level view.
AMD's mfma32x32x1_f32 is a batched matmul instruction that performs two matrix outer products. To see what happens, the instruction is equivalent to:
Logical Semantics
This is a batch matrix multiplication that divides the warp data into two 32x32 groups and performs the matmul.
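As a reference point, here is a minimal NumPy sketch of the logical semantics (shapes follow the C[2, 32, 32], A[2, 32], B[2, 32] buffers described below; since the k extent of the intrinsic is 1, each group is a rank-1, outer-product update, using the same C[i, j] += A[i, k] * B[j, k] convention as the code above):

import numpy as np

def mfma_32x32x1_logical(C, A, B):
    # C: (2, 32, 32), A: (2, 32), B: (2, 32)
    for b in range(2):          # the two 32x32 groups
        for i in range(32):
            for j in range(32):
                C[b, i, j] += A[b, i] * B[b, j]
    return C

# quick check against the vectorized equivalent
C = np.zeros((2, 32, 32), dtype="float32")
A = np.random.rand(2, 32).astype("float32")
B = np.random.rand(2, 32).astype("float32")
assert np.allclose(mfma_32x32x1_logical(C.copy(), A, B),
                   C + np.einsum("bi,bj->bij", A, B))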
In order to implement the above logical semantics, C[2, 32, 32], B[2, 32] and A[2, 32] are stored as special registers in warp memory, using the following rule (<=> means the memory map relation, wid is the warp index):

Namely, the data of A, B and C need to be laid out in a special way in warp-level memory, which in turn maps to the corresponding registers (by removing the wid component).
The actual GPU code looks as follows (using a simple example to illustrate the intrinsics):
The above kernel performs
In order to perform the matrix multiplication (tensorization), we need to perform the following steps:
Using BatchMatMul Intrinsic to Implement Matmul
It is possible to use the batch matmul intrinsic above to implement a matmul (by replicating the elements of one side). The logic is as follows (defining BB, AA, CC as the inputs and outputs of the matmul):
Then we have the following relationship:
This is exactly a 64x32 matmul.
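A minimal NumPy sketch of this construction (the exact packing is not shown in the post, so the shapes and the choice of replicating the B side are assumptions; the k = 1 case is used to match the intrinsic):

import numpy as np

# assumed shapes for the k = 1 case: AA is (64,), BB is (32,), CC is (64, 32)
AA = np.random.rand(64).astype("float32")
BB = np.random.rand(32).astype("float32")
CC = np.zeros((64, 32), dtype="float32")

# pack the operands of the 2x(32x32x1) intrinsic:
# the A side is split into the two groups, the B side is replicated
A_batch = AA.reshape(2, 32)
B_batch = np.stack([BB, BB])
C_batch = np.zeros((2, 32, 32), dtype="float32")

# logical semantics of the intrinsic (see the sketch under Logical Semantics)
C_batch += np.einsum("bi,bj->bij", A_batch, B_batch)

# unpacking the batched output yields the 64x32 matmul result
CC += C_batch.reshape(64, 32)
assert np.allclose(CC, np.outer(AA, BB))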
Challenges and Questions
We can find the following challenges that arise when tensorizing warp-level primitives.
It would be useful to discuss possible ways to solve these challenges, for example: