diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index f68644305b90..bdce15fc0bd7 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -26,6 +26,7 @@
 from tvm.tir import IterVar, PrimExpr, Var
 from tvm.tir.analysis import undefined_vars
 from tvm.tir.schedule.schedule import BlockRV
+from tvm.script import tir as T
 
 from ..base import analysis, BlockInfo, IterInfo
 from .base import GPUScheduleRule
@@ -942,14 +943,14 @@ def get_configs(self, target: Target) -> Config:
         ):
             return Matmul.Config(
                 block_size_x=32,
-                block_size_y=8,
+                block_size_y=4,
                 vthread_x=1,
                 vthread_y=1,
                 micro_size_x=8,
                 micro_size_y=2,
                 micro_size_k=16,
                 vector_size=8,
-                unroll=4,
+                unroll=16,
                 use_shared=False,
                 storage_align=False,
                 inner_x=True,
@@ -1144,7 +1145,7 @@ def get_max_factor(n, factors):
         if not (
             isinstance(sch.get(n).extent, tir.IntImm)
            and isinstance(sch.get(mb).extent, tir.IntImm)
-            and isinstance(sch.get(ms).extent, tir.Var)
+            and not isinstance(sch.get(ms).extent, tir.IntImm)
         ):
             return None
 
@@ -1154,6 +1155,7 @@ def get_max_factor(n, factors):
             config.vector_size,
             config.unroll,
         )
+        VecSize = min(get_max_factor(sch.get(n).extent // Threads_X, [1, 2, 4, 8]), VecSize)
 
         dequant_block = None
         matmul_block = reduction_block
@@ -1166,61 +1168,73 @@ def get_max_factor(n, factors):
             elif blk is not matmul_block:
                 sch.compute_inline(blk)
 
-        m = sch.fuse(mb, ms)
-
-        sch.pad_einsum(matmul_block, [1, Threads_Y * Unroll_M, Threads_X * VecSize, 1])
-
-        rmat_block, wmat_block = (
+        block = sch.reindex(reduction_block, ("read", 0))
+        sch.pad_einsum(reduction_block, [1, Unroll_M, 1, 1])
+        sch.compute_inline(block)
+        trans_block, matmul_reindex = (
             sch.get_producers(matmul_block)[0],
             sch.get_consumers(matmul_block)[0],
         )
-        mo, mi, mu = sch.split(m, [None, Threads_Y, Unroll_M])
-        no, ni, nv = sch.split(n, [None, Threads_X, VecSize])
-        k0, k1, k2, k3 = sch.split(k, [None, (Threads_X * VecSize) // 32, 4, 8])
-        sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)
-        sch.compute_at(rmat_block, k0)
-        if dequant_block is not None:
-            sch.compute_at(dequant_block, k3)
-        sch.reverse_compute_at(wmat_block, mi)
-        sch.set_scope(rmat_block, 0, "shared")
-        sch.set_scope(matmul_block, 0, "local")
+        if epilogue_block is not None:
+            sch.compute_inline(matmul_reindex)
+            matmul_reindex = epilogue_block
 
-        if dequant_block is not None:
-            sch.set_scope(dequant_block, 0, "local")
+        sch.transform_layout(
+            trans_block,
+            ("write", 0),
+            T.index_map(lambda i0, i1, i2: (i0, i1 // Unroll_M, i2, i1 % Unroll_M)),
+        )
 
-        sch.bind(mo, "blockIdx.y")
-        sch.bind(no, "blockIdx.x")
-        sch.bind(mi, "threadIdx.y")
-        sch.bind(ni, "threadIdx.x")
-        sch.vectorize(sch.get_loops(matmul_block)[-1])
+        # transpose block schedules
+        # sch.set_scope(trans_block, 0, "global.texture-1d")
+        tb, tn, tk = sch.get_loops(trans_block)
+        tbx, ttx = sch.split(tk, [None, Threads_X])
+        tby, tty, tc = sch.split(tn, [None, Threads_Y, Unroll_M])
+        sch.bind(tb, "blockIdx.z")
+        sch.bind(tby, "blockIdx.y")
+        sch.bind(tbx, "blockIdx.x")
+        sch.bind(tty, "threadIdx.y")
+        sch.bind(ttx, "threadIdx.x")
+        sch.reorder(tb, tby, tbx, tty, ttx, tc)
+        sch.vectorize(tc)
+
+        mb, ms, n, k = sch.get_loops(matmul_block)
+        m = sch.fuse(mb, ms)
+        bx, tx, vec = sch.split(n, [None, Threads_X, VecSize])
+        by, ty, unr = sch.split(m, [None, Threads_Y, Unroll_M])
+        k1, k2, k3 = sch.split(k, [None, 4, 8])
+        sch.reorder(bx, by, tx, ty, k1, k2, k3, unr, vec)
+        sch.set_scope(matmul_block, 0, "local")
 
         if dequant_block is not None:
-            sch.vectorize(sch.get_loops(dequant_block)[-1])
+            sch.compute_at(dequant_block, k3)
+            sch.set_scope(dequant_block, 0, "local")
+        sch.bind(by, "blockIdx.y")
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+        sch.vectorize(vec)
 
-        # Co-operative Memory Fetch
-        ro, rv = sch.split(sch.get_loops(rmat_block)[-1], [None, VecSize])
-        sch.bind(ro, "threadIdx.x")
-        sch.vectorize(rv)
+        inp = sch.cache_read(matmul_block, read_buffer_index=0, storage_scope="local")
+        sch.compute_at(inp, k3, preserve_unit_loops=True)
+        sch.vectorize(sch.get_loops(inp)[-1])
 
-        wv = sch.get_loops(wmat_block)[-1]
-        sch.vectorize(wv)
+        sch.unroll(unr)
+        sch.unroll(k3)
 
-        # Scale and Quant Cache
         if dequant_block is not None:
-            qb = sch.cache_read(dequant_block, 0, "local")
-            sb = sch.cache_read(dequant_block, 1, "local")
-            sch.compute_at(sb, k1)
-            sch.compute_at(qb, k2)
-            sch.set_scope(sb, 0, "local")
-            sch.set_scope(qb, 0, "local")
-            sch.vectorize(sch.get_loops(qb)[-1])
-            sch.vectorize(sch.get_loops(sb)[-1])
+            Aq_local = sch.cache_read(dequant_block, read_buffer_index=0, storage_scope="local")
+            sch.compute_at(Aq_local, k2, preserve_unit_loops=True)
+            sch.vectorize(sch.get_loops(Aq_local)[-1])
+            As_local = sch.cache_read(dequant_block, read_buffer_index=1, storage_scope="local")
+            sch.compute_at(As_local, k1, preserve_unit_loops=True)
+            sch.vectorize(sch.get_loops(As_local)[-1])
+            sch.vectorize(sch.get_loops(dequant_block)[-1])
 
-        if epilogue_block is not None:
-            sch.reverse_compute_at(epilogue_block, mi, preserve_unit_loops=True)
-            sch.set_scope(wmat_block, 0, "local")
-            sch.compute_inline(wmat_block)
-            sch.vectorize(sch.get_loops(epilogue_block)[-1])
+        sch.reverse_compute_at(matmul_reindex, ty)
+        o_ur, o_vec = sch.get_loops(matmul_reindex)[-2:]
+        sch.vectorize(o_vec)
+        sch.unroll(o_ur)
+        sch.decompose_reduction(matmul_block, k1)
 
-        sch.decompose_reduction(matmul_block, k0)
         return sch
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index dc5276e62a5f..83b52efc3a69 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -634,49 +634,68 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
         inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
         matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
         # with T.block("root"):
-        inp0_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
-        matmul_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
+        inp0_reindex_pad = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)))
+        matmul_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(4096)), scope="local")
+        inp0_reindex_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), scope="local")
+        for i0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+            for i1_0 in T.thread_binding(((m + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"):
+                for i2_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"):
+                    for i1_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+                        for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                            for i1_2 in T.vectorized(T.int64(16)):
+                                with T.block("inp0_reindex_pad"):
+                                    v0 = T.axis.spatial(T.int64(1), i0)
+                                    v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i1_0 * T.int64(64) + i1_1 * T.int64(16) + i1_2)
+                                    v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(32) + i2_1)
+                                    T.where((i1_0 * T.int64(4) + i1_1) * T.int64(16) + i1_2 < (m + T.int64(15)) // T.int64(16) * T.int64(16))
+                                    T.reads(inp0[v0, v1, v2])
+                                    T.writes(inp0_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)])
+                                    inp0_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
         for i2_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
-            for i0_i1_fused_0 in T.thread_binding((m + T.int64(31)) // T.int64(32), thread="blockIdx.y"):
+            for i0_i1_fused_0 in T.thread_binding(((m + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"):
                 for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
-                    for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"):
-                        for i0_i1_fused_2_init in range(T.int64(4)):
+                    for i0_i1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"):
+                        for i0_i1_fused_2_init in T.unroll(T.int64(16)):
                             for i2_2_init in T.vectorized(T.int64(8)):
                                 with T.block("matmul_init"):
                                     v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                    v_i1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2_init)
+                                    v_i1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2_init)
                                     v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2_init)
+                                    T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2_init < (m + T.int64(15)) // T.int64(16) * T.int64(16))
                                     T.reads()
                                     T.writes(matmul_pad_local[v_i0, v_i1, v_i2])
                                     matmul_pad_local[v_i0, v_i1, v_i2] = T.float32(0)
-                        for k_0 in range(T.int64(16)):
-                            for ax0 in range(T.int64(4)):
-                                for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
-                                    for ax1_1 in T.vectorized(T.int64(8)):
-                                        with T.block("inp0_pad"):
-                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                            v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
-                                            v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1)
-                                            T.reads(inp0[v0, v1, v2])
-                                            T.writes(inp0_pad_shared[v0, v1, v2])
-                                            inp0_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
-                            for k_1, k_2, k_3, i0_i1_fused_2 in T.grid(T.int64(8), T.int64(4), T.int64(8), T.int64(4)):
-                                for i2_2 in T.vectorized(T.int64(8)):
-                                    with T.block("matmul_update"):
-                                        v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                        v_i1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2)
-                                        v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
-                                        v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
-                                        T.reads(matmul_pad_local[v_i0, v_i1, v_i2], inp0_pad_shared[v_i0, v_i1, v_k], inp1[v_k, v_i2])
-                                        T.writes(matmul_pad_local[v_i0, v_i1, v_i2])
-                                        matmul_pad_local[v_i0, v_i1, v_i2] = matmul_pad_local[v_i0, v_i1, v_i2] + inp0_pad_shared[v_i0, v_i1, v_k] * inp1[v_k, v_i2]
-                        for ax0 in range(T.int64(4)):
+                        for k_0, k_1 in T.grid(T.int64(128), T.int64(4)):
+                            for k_2 in T.unroll(T.int64(8)):
+                                for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)):
+                                    for ax3 in T.vectorized(T.int64(16)):
T.block("inp0_reindex_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial((m + T.int64(15)) // T.int64(16), i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 + ax1) + v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax2) + v3 = T.axis.spatial(T.int64(16), ax3) + T.where(i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 < (m + T.int64(15)) // T.int64(16)) + T.reads(inp0_reindex_pad[v0, v1, v2, v3]) + T.writes(inp0_reindex_pad_local[v0, v1, v2, v3]) + inp0_reindex_pad_local[v0, v1, v2, v3] = inp0_reindex_pad[v0, v1, v2, v3] + for i0_i1_fused_2 in T.unroll(T.int64(16)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial((m + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2) + v_i2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2 < (m + T.int64(15)) // T.int64(16) * T.int64(16)) + T.reads(matmul_pad_local[v_i0, v_i1, v_i2], inp0_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)], inp1[v_k, v_i2]) + T.writes(matmul_pad_local[v_i0, v_i1, v_i2]) + matmul_pad_local[v_i0, v_i1, v_i2] = matmul_pad_local[v_i0, v_i1, v_i2] + inp0_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)] * inp1[v_k, v_i2] + for ax0 in T.unroll(T.int64(16)): for ax1 in T.vectorized(T.int64(8)): with T.block("matmul_pad"): v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) + v1 = T.axis.spatial(m, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0) v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) - T.where((i0_i1_fused_0 - (m + T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < m) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (m + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < m) T.reads(matmul_pad_local[v0, v1, v2]) T.writes(matmul[v0, v1, v2]) matmul[v0, v1, v2] = matmul_pad_local[v0, v1, v2] @@ -729,75 +748,94 @@ def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") # with T.block("root"): dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local") - rms_norm130_pad_shared = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16", scope="shared") - matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16", scope="local") + rms_norm130_reindex_pad = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), "float16") + matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16) * T.int64(16), T.int64(12288)), "float16", scope="local") + rms_norm130_reindex_pad_local = T.alloc_buffer((T.int64(1), (seq_len + T.int64(15)) // T.int64(16), T.int64(4096), T.int64(16)), "float16", 
scope="local") lv452_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32", scope="local") lv453_local = T.alloc_buffer((T.int64(128), T.int64(12288)), "float16", scope="local") + for i0 in T.thread_binding(T.int64(1), thread="blockIdx.z"): + for i1_0 in T.thread_binding(((seq_len + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"): + for i2_0 in T.thread_binding(T.int64(128), thread="blockIdx.x"): + for i1_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for i1_2 in T.vectorized(T.int64(16)): + with T.block("rms_norm130_reindex_pad"): + v0 = T.axis.spatial(T.int64(1), i0) + v1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i1_0 * T.int64(64) + i1_1 * T.int64(16) + i1_2) + v2 = T.axis.spatial(T.int64(4096), i2_0 * T.int64(32) + i2_1) + T.where((i1_0 * T.int64(4) + i1_1) * T.int64(16) + i1_2 < (seq_len + T.int64(15)) // T.int64(16) * T.int64(16)) + T.reads(rms_norm130[v0, v1, v2]) + T.writes(rms_norm130_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)]) + rms_norm130_reindex_pad[v0, v1 // T.int64(16), v2, v1 % T.int64(16)] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0)) for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"): - for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) // T.int64(32), thread="blockIdx.y"): + for i0_i1_fused_0 in T.thread_binding(((seq_len + T.int64(15)) // T.int64(16) * T.int64(16) + T.int64(63)) // T.int64(64), thread="blockIdx.y"): for i2_1 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.y"): - for i0_i1_fused_2_init in range(T.int64(4)): + for i0_i1_fused_1 in T.thread_binding(T.int64(4), thread="threadIdx.y"): + for i0_i1_fused_2_init in T.unroll(T.int64(16)): for i2_2_init in T.vectorized(T.int64(8)): with T.block("matmul_init"): v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2_init) + v_i1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2_init) v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2_init) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2_init < (seq_len + T.int64(15)) // T.int64(16) * T.int64(16)) T.reads() T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = T.float16(0) - for k_0 in range(T.int64(16)): - for ax0 in range(T.int64(4)): - for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"): - for ax1_1 in T.vectorized(T.int64(8)): - with T.block("rms_norm130_pad"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0) - v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1) - T.reads(rms_norm130[v0, v1, v2]) - T.writes(rms_norm130_pad_shared[v0, v1, v2]) - rms_norm130_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0)) - for k_1 in range(T.int64(8)): - for ax0 in T.vectorized(T.int64(8)): + for k_0 in range(T.int64(128)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): with 
T.block("lv453_local"): - v0 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + k_1) - v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) + v0 = T.axis.spatial(T.int64(128), k_0 + ax0) + v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) T.reads(lv453[v0, v1]) T.writes(lv453_local[v0, v1]) lv453_local[v0, v1] = lv453[v0, v1] - for k_2 in range(T.int64(4)): - for ax0 in T.vectorized(T.int64(8)): + for k_1 in range(T.int64(4)): + for ax0 in range(T.int64(1)): + for ax1 in T.vectorized(T.int64(8)): with T.block("lv452_local"): - v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2) - v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) + v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(4) + k_1 + ax0) + v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) T.reads(lv452[v0, v1]) T.writes(lv452_local[v0, v1]) lv452_local[v0, v1] = lv452[v0, v1] - for k_3 in range(T.int64(8)): - for ax0 in T.vectorized(T.int64(8)): - with T.block("dequantize"): - v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) - v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) - T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1]) - T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1]) - dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1] - for i0_i1_fused_2 in range(T.int64(4)): - for i2_2 in T.vectorized(T.int64(8)): - with T.block("matmul_update"): - v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) - v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2) - v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) - v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3) - T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2]) - T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) - matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2] - for ax0, ax1 in T.grid(T.int64(1), T.int64(4)): - for ax2 in T.vectorized(T.int64(8)): + for k_2 in T.unroll(T.int64(8)): + for ax0 in T.vectorized(T.int64(8)): + with T.block("dequantize"): + v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2) + v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0) + T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1]) + T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1]) + dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(1)): + for ax3 in T.vectorized(T.int64(16)): + with 
T.block("rms_norm130_reindex_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16), i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 + ax1) + v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2 + ax2) + v3 = T.axis.spatial(T.int64(16), ax3) + T.where(i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 < (seq_len + T.int64(15)) // T.int64(16)) + T.reads(rms_norm130_reindex_pad[v0, v1, v2, v3]) + T.writes(rms_norm130_reindex_pad_local[v0, v1, v2, v3]) + rms_norm130_reindex_pad_local[v0, v1, v2, v3] = rms_norm130_reindex_pad[v0, v1, v2, v3] + for i0_i1_fused_2 in T.unroll(T.int64(16)): + for i2_2 in T.vectorized(T.int64(8)): + with T.block("matmul_update"): + v_i0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_i1 = T.axis.spatial((seq_len + T.int64(15)) // T.int64(16) * T.int64(16), i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + i0_i1_fused_2) + v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2) + v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(32) + k_1 * T.int64(8) + k_2) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1) * T.int64(16) + i0_i1_fused_2 < (seq_len + T.int64(15)) // T.int64(16) * T.int64(16)) + T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)], dequantize_intermediate_intermediate_local[v_k, v_i2]) + T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2]) + matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_reindex_pad_local[v_i0, v_i1 // T.int64(16), v_k, v_i1 % T.int64(16)] * dequantize_intermediate_intermediate_local[v_k, v_i2] + for ax0 in T.unroll(T.int64(16)): + for ax1 in T.vectorized(T.int64(8)): with T.block("T_add"): - v_ax0 = T.axis.spatial(T.int64(1), ax0) - v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1) - v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax2) - T.where(i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1 < seq_len) + v_ax0 = T.axis.spatial(T.int64(1), T.int64(0)) + v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0) + v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1) + T.where((i0_i1_fused_0 * T.int64(4) + i0_i1_fused_1 - (seq_len + T.int64(15)) // T.int64(16) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(64) + i0_i1_fused_1 * T.int64(16) + ax0 < seq_len) T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2]) T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2]) T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]