diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
index 798f00423f7b..d111bdb42abb 100644
--- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
+++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
@@ -65,6 +65,23 @@ Optional<BlockRV> ParseAnnotate(const Schedule& sch, const Instruction& inst,
   return Downcast<BlockRV>(inst->inputs[0]);
 }
 
+/*!
+ * \brief Parse instruction: sch.annotate(..., attr::warp_execution)
+ * \param sch The schedule
+ * \param inst The instruction to be parsed
+ * \return Whether the parsing is successful
+ */
+bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) {
+  static InstructionKind inst_kind_annotate = InstructionKind::Get("Annotate");
+  if (!inst->kind.same_as(inst_kind_annotate)) {
+    return false;
+  }
+  ICHECK_EQ(inst->inputs.size(), 2);
+  ICHECK_EQ(inst->attrs.size(), 1);
+  String ann_key = Downcast<String>(inst->attrs[0]);
+  return ann_key == attr::warp_execution;
+}
+
 }  // namespace tir
 
 namespace meta_schedule {
@@ -76,7 +93,14 @@ namespace meta_schedule {
 class RewriteCooperativeFetchNode : public PostprocNode {
  public:
   // Inherited from PostprocNode
-  void InitializeWithTuneContext(const TuneContext& context) final {}
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size")) {
+      this->thread_warp_size_ = v.value()->value;
+    } else {
+      TVM_PY_LOG(INFO, context->logging_func) << "'thread_warp_size' is not defined in the target";
+    }
+  }
+
   // Inherited from PostprocNode
   bool Apply(const tir::Schedule& sch) final;
 
@@ -84,6 +108,9 @@ class RewriteCooperativeFetchNode : public PostprocNode {
 
   static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch";
   TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode);
+
+ private:
+  int thread_warp_size_ = -1;
 };
 
 bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
@@ -101,6 +128,10 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
       thread_extent_y = new_thread_extent.value()->value;
       continue;
     }
+    if (tir::ParseWarpExecutionAnn(sch, inst)) {
+      thread_extent_x = thread_warp_size_;
+      continue;
+    }
     Optional<BlockRV> opt_block_rv = tir::ParseAnnotate(sch, inst, &vector_lane);
     if (!opt_block_rv.defined()) {
       continue;
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py
index 5460c5900946..e55f693e72d3 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py
@@ -17,6 +17,7 @@
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import tvm
+import tvm.testing
 from tvm import tir
 from tvm.meta_schedule import TuneContext
 from tvm.meta_schedule.postproc import RewriteCooperativeFetch
@@ -99,6 +100,108 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None:
                 C[v0, v1] = C_local[v0, v1]
 
 
+@tvm.script.ir_module
+class WarpExecutionAfterRewrite:
+    @T.prim_func
+    def main(
+        A: T.Buffer[(512, 512), "float32"],
+        B: T.Buffer[(512, 512), "float32"],
+        C: T.Buffer[(512, 512), "float32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local")
+        A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
+        B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared")
+        for i0_0_i1_0_fused in T.thread_binding(0, 16, thread="blockIdx.x"):
+            for i0_1_i1_1_fused in T.thread_binding(0, 16, thread="vthread.x"):
+                for i0_2_i1_2_fused in T.thread_binding(0, 8, thread="threadIdx.y"):
+                    for i2_0 in T.serial(0, 1):
+                        for ax0_ax1_fused_0 in T.serial(0, 1024):
+                            for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"):
+                                for ax0_ax1_fused_2 in T.thread_binding(
+                                    0, 32, thread="threadIdx.x"
+                                ):
+                                    with T.block("A_shared"):
+                                        v0 = T.axis.spatial(
+                                            512,
+                                            (
+                                                ax0_ax1_fused_0 * 256
+                                                + ax0_ax1_fused_1 * 32
+                                                + ax0_ax1_fused_2
+                                            )
+                                            // 512,
+                                        )
+                                        v1 = T.axis.spatial(
+                                            512,
+                                            (
+                                                ax0_ax1_fused_0 * 256
+                                                + ax0_ax1_fused_1 * 32
+                                                + ax0_ax1_fused_2
+                                            )
+                                            % 512,
+                                        )
+                                        T.reads([A[v0, v1]])
+                                        T.writes([A_shared[v0, v1]])
+                                        A_shared[v0, v1] = A[v0, v1]
+                        for ax0_ax1_fused_0 in T.serial(0, 32):
+                            for ax0_ax1_fused_1 in T.thread_binding(0, 8, thread="threadIdx.y"):
+                                for ax0_ax1_fused_2 in T.thread_binding(
+                                    0, 32, thread="threadIdx.x"
+                                ):
+                                    for ax0_ax1_fused_3 in T.vectorized(0, 2):
+                                        with T.block("B_shared"):
+                                            v0 = T.axis.spatial(
+                                                512,
+                                                (
+                                                    ax0_ax1_fused_0 * 512
+                                                    + ax0_ax1_fused_1 * 64
+                                                    + ax0_ax1_fused_2 * 2
+                                                    + ax0_ax1_fused_3
+                                                )
+                                                // 32,
+                                            )
+                                            v1 = T.axis.spatial(
+                                                512,
+                                                i0_0_i1_0_fused * 32
+                                                + (
+                                                    ax0_ax1_fused_0 * 512
+                                                    + ax0_ax1_fused_1 * 64
+                                                    + ax0_ax1_fused_2 * 2
+                                                    + ax0_ax1_fused_3
+                                                )
+                                                % 32,
+                                            )
+                                            T.reads([B[v0, v1]])
+                                            T.writes([B_shared[v0, v1]])
+                                            B_shared[v0, v1] = B[v0, v1]
+                        for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(16, 2, 2, 32, 16, 2):
+                            with T.block("C"):
+                                i = T.axis.spatial(512, i0_1_i1_1_fused * 32 + i0_3 * 16 + i0_4)
+                                j = T.axis.spatial(
+                                    512,
+                                    i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + i1_3 * 2 + i1_4,
+                                )
+                                k = T.axis.reduce(512, i2_0 * 512 + i2_1 * 32 + i2_2)
+                                T.reads([A_shared[i, k], B_shared[k, j]])
+                                T.writes([C_local[i, j]])
+                                T.block_attr({"warp_execution": 1})
+                                with T.init():
+                                    C_local[i, j] = T.float32(0)
+                                C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j]
+                    for ax0, ax1 in T.grid(32, 4):
+                        with T.block("C_local"):
+                            v0 = T.axis.spatial(512, i0_1_i1_1_fused * 32 + ax0)
+                            v1 = T.axis.spatial(
+                                512, i0_0_i1_0_fused * 32 + i0_2_i1_2_fused * 4 + ax1
+                            )
+                            T.reads([C_local[v0, v1]])
+                            T.writes([C[v0, v1]])
+                            C[v0, v1] = C_local[v0, v1]
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
 # fmt: on
 
@@ -147,5 +250,50 @@ def test_rewrite_cooperative_fetch():
     tvm.ir.assert_structural_equal(sch.mod, AfterRewrite0)
 
 
+def test_rewrite_warp_execution():
+    mod = create_prim_func(te_workload.matmul(n=512, m=512, k=512))
+    target = _target()
+    ctx = _create_context(mod, target)
+
+    sch = tir.Schedule(mod, debug_mask="all")
+    # fmt: off
+    # pylint: disable=line-too-long,invalid-name
+    b0 = sch.get_block(name="C", func_name="main")
+    b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")
+    l2, l3, l4 = sch.get_loops(block=b0)
+    sch.annotate(b0, "warp_execution", 1)
+    v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64, decision=[1, 16, 1, 2, 16])
+    l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])
+    v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64, decision=[16, 1, 8, 2, 2])
+    l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])
+    v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64, decision=[1, 16, 32])
+    l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])
+    sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)
+    l31 = sch.fuse(l10, l20)
+    sch.bind(loop=l31, thread_axis="blockIdx.x")
+    l32 = sch.fuse(l11, l21)
+    sch.bind(loop=l32, thread_axis="vthread.x")
+    l33 = sch.fuse(l12, l22)
+    sch.bind(loop=l33, thread_axis="threadIdx.y")
+    b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")
+    sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)
+    _, _, _, _, l39, l40 = sch.get_loops(block=b34)
+    l41 = sch.fuse(l39, l40)
+    _, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4, decision=[262144, 1])
+    sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)
+    b44 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")
+    sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)
+    _, _, _, _, l49, l50 = sch.get_loops(block=b44)
+    l51 = sch.fuse(l49, l50)
+    _, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4, decision=[8192, 2])
+    sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)
+    sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)
+    # pylint: enable=line-too-long,invalid-name
+    # fmt: on
+    sch.enter_postproc()
+    assert ctx.postprocs[0].apply(sch)
+    tvm.ir.assert_structural_equal(sch.mod, WarpExecutionAfterRewrite)
+
+
 if __name__ == "__main__":
-    test_rewrite_cooperative_fetch()
+    tvm.testing.main()
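
Note (not part of the patch): the new `warp_execution` path makes the postproc
assume an implicit `threadIdx.x` of warp width, so `thread_extent_x` is taken
from the target's `thread_warp_size` attribute rather than from a loop bound to
`threadIdx.x` in the trace. A minimal sketch of where that value comes from,
assuming a CUDA-enabled TVM build; the bare `Target("cuda")` and the asserted
value are illustrative:

    import tvm

    # "thread_warp_size" is a standard attribute of CUDA targets; when a target
    # does not define it, the postproc keeps its -1 default and logs a message.
    target = tvm.target.Target("cuda")
    warp_size = int(target.attrs["thread_warp_size"])
    assert warp_size == 32  # typical warp width on NVIDIA GPUs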
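
For orientation, the rewrite that `Apply` performs on each cooperative-fetch
block can be sketched against the public `tir.Schedule` API. This is a
simplified illustration under stated assumptions (the helper name
`rewrite_fetch` and the 2-D fetch layout are hypothetical), not a transcription
of the C++ implementation:

    from tvm import tir

    def rewrite_fetch(sch: tir.Schedule, block, extent_y: int, extent_x: int, vector_lane: int):
        # Fuse the innermost fetch loops, split as [-1, ty, tx(, vec)], then
        # bind to threadIdx.y/threadIdx.x and vectorize the innermost lane.
        # Under warp_execution, extent_x is the target's thread_warp_size.
        fused = sch.fuse(*sch.get_loops(block)[-2:])  # assumption: 2-D fetch region
        if vector_lane > 1:
            _, ty, tx, vec = sch.split(fused, factors=[None, extent_y, extent_x, vector_lane])
            sch.vectorize(vec)
        else:
            _, ty, tx = sch.split(fused, factors=[None, extent_y, extent_x])
        sch.bind(ty, "threadIdx.y")
        sch.bind(tx, "threadIdx.x")

In the test above this corresponds to the rewritten "A_shared"/"B_shared"
blocks: both fetch loops end up split by [_, 8, 32(, 2)] and bound to
threadIdx.y and threadIdx.x, where 32 is the warp size injected by the
annotation.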