You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# 128x128 matmul primfunc written as a single TensorIR block "C".
# Note: B is indexed as B[vj, vk], i.e. this computes C = A @ B^T.
# Comments (not a docstring) are used because the body is parsed as TVMScript.
@tvm.script.tir
def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    # Two spatial iter vars (vi, vj) and one reduction iter var (vk).
    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
        # Reduction initialisation, executed before the accumulation.
        with tir.init():
            C[vi, vj] = 0.0
        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
def main():
    """Schedule ``matmul``, blockize the inner tile and print the resulting TVMScript.

    The i and j axes are tiled by 32. The k axis is split with ``factor=1``,
    which yields a unit inner loop; that loop is what exposes the iter-var
    mismatch discussed in this issue. Everything inside ``i1`` is then blockized.
    """
    mod = tvm.script.create_module({'main': matmul})
    s = tir.Schedule(mod)
    C = s.get_block('C')
    i, j, k = s.get_axes(C)
    i0, i1 = s.split(i, factor=32)
    j0, j1 = s.split(j, factor=32)
    # factor=1 deliberately produces the unit loop that triggers the issue.
    k0, k1 = s.split(k, factor=1)
    s.reorder(i0, j0, k0, i1, j1, k1)
    s.blockize(i1)
    print(tvm.script.asscript(s.mod['main']))


# Guard the entry point so importing this file does not run the schedule.
if __name__ == "__main__":
    main()
Result:
# Printed output of main(): `matmul` after split/reorder/blockize.
@tvm.script.tir
def func(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = tir.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
    # body
    with tir.block([], "root"):
        tir.reads([])
        tir.writes([])
        for i0_outer, i1_outer, i2_outer in tir.grid(4, 4, 128):
            # Outer block produced by blockize; it owns the reduction iter var vk.
            with tir.block([4, 4, tir.reduce_axis(0, 128)], "blockized_C") as [vio, vjo, vk]:
                tir.bind(vio, i0_outer)
                tir.bind(vjo, i1_outer)
                tir.bind(vk, i2_outer)
                tir.reads([C[(vio*32):((vio*32) + 32), (vjo*32):((vjo*32) + 32)], A[(vio*32):((vio*32) + 32), vk], B[(vjo*32):((vjo*32) + 32), vk]])
                tir.writes([C[(vio*32):((vio*32) + 32), (vjo*32):((vjo*32) + 32)]])
                with tir.init():
                    for i0_inner, i1_inner in tir.grid(32, 32):
                        with tir.block([128, 128], "C_init") as [vi_init, vj_init]:
                            tir.bind(vi_init, ((vio*32) + i0_inner))
                            tir.bind(vj_init, ((vjo*32) + i1_inner))
                            tir.reads([])
                            tir.writes([C[vi_init, vj_init]])
                            C[vi_init, vj_init] = tir.float32(0)
                # Three enclosing loops (the last, i2_inner, has extent 1) ...
                for i0_inner_1, i1_inner_1, i2_inner in tir.grid(32, 32, 1):
                    # ... but block C declares only two iter vars and reads the
                    # outer block's `vk` directly -- the inconsistency in question.
                    with tir.block([128, 128], "C") as [vi, vj]:
                        tir.bind(vi, ((vio*32) + i0_inner_1))
                        tir.bind(vj, ((vjo*32) + i1_inner_1))
                        tir.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
                        tir.writes([C[vi, vj]])
                        C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vj, vk]))
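The result contains the unit loop `i2_inner`, but block `C` has only two iter vars (`vi`, `vj`), which is inconsistent with the three loops that enclose it. Instead, block `C` reads `vk`, an iter var of the outer block `blockized_C`, directly.
What is the opinion on using outer block iter vars directly inside the inner block `C`?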
Proposal 1: Add an iter var to block `C` that is bound to the unit loop `i2_inner`.
See the following example:
Result:
The result contains the unit loop `i2_inner`, but block `C` has only two iter vars, which is inconsistent with the number of outer loops. What is the opinion on using outer block iter vars directly inside the inner block `C`?
… `i2_inner`, but not change block `C`.
The text was updated successfully, but these errors were encountered: