diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index e17f955963..bde0d40a11 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -152,7 +152,6 @@ class ScheduleRule : public runtime::ObjectRef {
                                                Optional<Integer> vector_load_max_len,        //
                                                Optional<Map<String, ObjectRef>> reuse_read,  //
                                                Optional<Map<String, ObjectRef>> reuse_write);
-
   /*!
    * \brief Create a rule: add-rfactor to some blocks if needed
    * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index 3fe3f0fb3b..c269b0e617 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -21,3 +21,4 @@
 from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
 from .random_compute_location import RandomComputeLocation
 from .schedule_rule import PyScheduleRule, ScheduleRule
+from .add_rfactor import AddRFactor
\ No newline at end of file
diff --git a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py
index 176c804cd5..748684d344 100644
--- a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py
+++ b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py
@@ -25,47 +25,25 @@
 @register_object("meta_schedule.AddRFactor")
 class AddRFactor(ScheduleRule):
-    """Rule that inlines spatial blocks if it satisfies some conditions
+    """Rule that applies add-rfactor to some blocks if needed.
 
     Parameters
     ----------
-    into_producer : bool
-        If allows to inline a block into its producer
-    into_consumer : bool
-        If allows to inline a block into its consumer
-    into_cache_only : bool
-        If it only allows to inline into a block generated by cache_read/write
-    inline_const_tensor : bool
-        Always inline constant tensors
-    disallow_if_then_else : bool
-        Always disallow if-then-else-like constructs
-    require_injective : bool
-        Always require the read-to-write mapping to be ordered
-    require_ordered : bool
-        Always require the read-to-write mapping to be injective
-    disallow_op : Optional[List[str]]
-        The operators that are disallowed in auto inline
+    max_jobs_per_core: int
+        The maximum number of jobs to be launched per CPU core. It sets the upper limit of CPU
+        parallelism, i.e. `num_cores * max_jobs_per_core`.
+        Use -1 to disable parallelism.
+    max_innermost_factor: Optional[int] = None
+        The maximum size of the innermost factor. NullOpt means no limit.
""" def __init__( self, - into_producer: bool, - into_consumer: bool, - into_cache_only: bool, - inline_const_tensor: bool, - disallow_if_then_else: bool, - require_injective: bool, - require_ordered: bool, - disallow_op: Optional[List[str]] = None, + max_jobs_per_core: int = 16, + max_innermost_factor: Optional[int] = None, ) -> None: self.__init_handle_by_constructor__( - _ffi_api.ScheduleRuleAutoInline, # type: ignore # pylint: disable=no-member - into_producer, - into_consumer, - into_cache_only, - inline_const_tensor, - disallow_if_then_else, - require_injective, - require_ordered, - disallow_op, + _ffi_api.ScheduleRuleAddRFactor, # type: ignore # pylint: disable=no-member + max_jobs_per_core, + max_innermost_factor, ) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index 03973488ac..a01fd3aa84 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -18,6 +18,7 @@ from typing import List from tvm.meta_schedule.schedule_rule import ( + AddRFactor, AutoInline, MultiLevelTiling, ParallelizeVectorizeUnroll, @@ -33,6 +34,7 @@ def get(target: Target) -> List[ScheduleRule]: if target.kind.name == "llvm": return [ auto_inline(target), + add_rfactor(target), multi_level_tiling(target), parallel_vectorize_unroll(target), ] @@ -159,3 +161,10 @@ def random_compute_location(target: Target) -> ScheduleRule: if target.kind.name == "llvm": return RandomComputeLocation() raise NotImplementedError(f"{target.kind.name} is not supported") + + +def add_rfactor(target: Target) -> ScheduleRule: + """Default schedule rules for with add_rfactor""" + if target.kind.name == "llvm": + return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64) + raise NotImplementedError(f"{target.kind.name} is not supported") \ No newline at end of file diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index 8acd8d277a..7fc94ab118 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -21,59 +21,6 @@ namespace tvm { namespace meta_schedule { -/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/ - -/*! - * \brief Reorder the reduction loops to innermost positions if needed. - * \param sch The schedule - * \param block_rv The block where to apply the reorder - * \param fused_reduce_loop The fusion-generated loop to return. - * \param num_spatial_loops The number of spatial loops to return. - * \note Before invoking this helper function, make sure that the block has only spatial and - * reduction loop axes. - */ -void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv, - tir::LoopRV* fused_reduce_loop, size_t* num_spatial_loops) { - Array loops = sch->GetLoops(block_rv); - Array loop_srefs; - for (const tir::LoopRV& loop_rv : loops) { - loop_srefs.push_back(sch->GetSRef(loop_rv)); - } - - Array new_order; - // Step 1. Add spatial loops. - *num_spatial_loops = 0; - for (size_t i = 0; i < loops.size(); ++i) { - if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) { - new_order.push_back(loops[i]); - (*num_spatial_loops)++; - } - } - // Step 2. Add reduction loops. - Array reduction_loops; - for (size_t i = 0; i < loops.size(); ++i) { - if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) { - new_order.push_back(loops[i]); - reduction_loops.push_back(loops[i]); - } - } - // Step 3. 
-  ICHECK_EQ(new_order.size(), loops.size());
-  for (size_t i = 0; i < loops.size(); ++i) {
-    if (!new_order[i].same_as(loops[i])) {
-      sch->Reorder(new_order);
-      break;
-    }
-  }
-  // Step 4. Fuse all the reduction loops if there are multiple reduction loops.
-  CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop";
-  if (reduction_loops.size() > 1) {
-    *fused_reduce_loop = sch->Fuse(reduction_loops);
-  } else {
-    *fused_reduce_loop = reduction_loops[0];
-  }
-}
-
 class AddRFactorNode : public ScheduleRuleNode {
  public:
   // Inherited from ScheduleRuleNode
@@ -88,58 +35,7 @@ class AddRFactorNode : public ScheduleRuleNode {
   }
 
   // Inherited from ScheduleRuleNode
-  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
-    if (!NeedAddRFactor(sch, block_rv)) {
-      return {sch};
-    }
-    // Make a copy of the original schedule.
-    tir::Schedule ori_sch = sch->Copy();
-    ori_sch->Seed(sch->ForkSeed());
-
-    // Reorder the loop axes if reduction loops are not innermost.
-    // After the reordering, fuse all the reduction loops.
-    size_t num_spatial_loops;
-    tir::LoopRV fused_reduce_loop;
-    ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
-
-    // Split the fused reduction loop.
-    Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
-    const Array<tir::LoopRV>& split_loops =
-        sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});
-
-    Array<tir::Schedule> res;
-    for (const tir::LoopRV& split_loop : split_loops) {
-      tir::Schedule sch_tmp = sch->Copy();
-      sch_tmp->Seed(sch->ForkSeed());
-      const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops);
-      Array<tir::LoopRV> axes = sch_tmp->GetLoops(block_rf);
-      ICHECK_GT(axes.size(), num_spatial_loops);
-      res.push_back(sch_tmp);
-    }
-
-    res.push_back(ori_sch);
-    return res;
-  }
-
-  bool NeedAddRFactor(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
-    tir::StmtSRef block_sref = sch->GetSRef(block_rv);
-    tir::ScheduleState state = sch->state();
-    const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
-    // Cond 1. The block has no annotations
-    if (!block->annotations.empty()) {
-      return false;
-    }
-    // Cond 2. The block has only one write buffer
-    if (block->writes.size() != 1) {
-      return false;
-    }
-    if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_,
-                                            max_parallel_basic_) ||
-        HasCacheWriteBlock(sch, block_rv, 0)) {
-      return false;
-    }
-    return true;
-  }
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv);
 
  public:
   /*!
@@ -164,7 +60,7 @@ class AddRFactorNode : public ScheduleRuleNode {
 
   static constexpr const char* _type_key = "meta_schedule.AddRFactor";
   TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode);
-};  // namespace meta_schedule
+};
 
 ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
                                       Optional<Integer> max_innermost_factor) {
@@ -176,5 +72,45 @@ ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
   return ScheduleRule(n);
 }
 
+Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
+  tir::StmtSRef block_sref = sch->GetSRef(block_rv);
+  if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_,
+                                          max_parallel_basic_)) {
+    return {sch};
+  }
+
+  // Make a copy of the original schedule.
+  tir::Schedule ori_sch = sch->Copy();
+  ori_sch->Seed(sch->ForkSeed());
+
+  // Reorder the loop axes if reduction loops are not innermost.
+  // After the reordering, fuse all the reduction loops.
+  size_t num_spatial_loops;
+  tir::LoopRV fused_reduce_loop;
+  ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
+
+  // Split the fused reduction loop.
+  Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
+  const Array<tir::LoopRV>& split_loops =
+      sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});
+
+  Array<tir::Schedule> res;
+  for (const tir::LoopRV& split_loop : split_loops) {
+    tir::Schedule sch_tmp = sch->Copy();
+    sch_tmp->Seed(sch->ForkSeed());
+    const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops);
+    Array<tir::LoopRV> axes = sch_tmp->GetLoops(block_rf);
+    ICHECK_GT(axes.size(), num_spatial_loops);
+    res.push_back(sch_tmp);
+  }
+
+  res.push_back(ori_sch);
+  return res;
+}
+
+TVM_REGISTER_NODE_TYPE(AddRFactorNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor")
+    .set_body_typed(ScheduleRule::AddRFactor);
+
 }  // namespace meta_schedule
 }  // namespace tvm
\ No newline at end of file
diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h
index d3b4450f19..77d68383c1 100644
--- a/src/meta_schedule/utils.h
+++ b/src/meta_schedule/utils.h
@@ -303,6 +303,60 @@ inline Optional<tir::Schedule> ApplyTrace(const IRModule& mod, const tir::Trace&
   return sch;
 }
 
+/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/
+
+/*!
+ * \brief Reorder the reduction loops to innermost positions if needed.
+ * \param sch The schedule
+ * \param block_rv The block where to apply the reorder
+ * \param fused_reduce_loop The fusion-generated loop to return.
+ * \param num_spatial_loops The number of spatial loops to return.
+ * \note Before invoking this helper function, make sure that the block has only spatial and
+ *       reduction loop axes.
+ */
+inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv,
+                                         tir::LoopRV* fused_reduce_loop,
+                                         size_t* num_spatial_loops) {
+  Array<tir::LoopRV> loops = sch->GetLoops(block_rv);
+  Array<tir::StmtSRef> loop_srefs;
+  for (const tir::LoopRV& loop_rv : loops) {
+    loop_srefs.push_back(sch->GetSRef(loop_rv));
+  }
+
+  Array<tir::LoopRV> new_order;
+  // Step 1. Add spatial loops.
+  *num_spatial_loops = 0;
+  for (size_t i = 0; i < loops.size(); ++i) {
+    if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) {
+      new_order.push_back(loops[i]);
+      (*num_spatial_loops)++;
+    }
+  }
+  // Step 2. Add reduction loops.
+  Array<tir::LoopRV> reduction_loops;
+  for (size_t i = 0; i < loops.size(); ++i) {
+    if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) {
+      new_order.push_back(loops[i]);
+      reduction_loops.push_back(loops[i]);
+    }
+  }
+  // Step 3. Apply reordering if new_order differs from the original order.
+  ICHECK_EQ(new_order.size(), loops.size());
+  for (size_t i = 0; i < loops.size(); ++i) {
+    if (!new_order[i].same_as(loops[i])) {
+      sch->Reorder(new_order);
+      break;
+    }
+  }
+  // Step 4. Fuse all the reduction loops if there are multiple reduction loops.
+  CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop";
+  if (reduction_loops.size() > 1) {
+    *fused_reduce_loop = sch->Fuse(reduction_loops);
+  } else {
+    *fused_reduce_loop = reduction_loops[0];
+  }
+}
+
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index ae68fbb231..cded75cfeb 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -550,16 +550,6 @@ bool HasOp(const Stmt& stmt, const Array<Op>& ops);
  */
 bool HasIfThenElse(const Stmt& stmt);
 
-/*!
- * \brief Checks if the given block has cache write blocks.
- * \param sch The traced schedule.
- * \param block_rv The given block.
- * \param write_buffer_index The index of the buffer in block's write region
- * \return A boolean indicating whether the block has its cache write block in the trace.
- */
-bool HasCacheWriteBlock(const tir::Schedule& sch, const BlockRV& block_rv,
-                        const int& write_buffer_index);
-
 /*!
  * \brief Checks if a block could be successfully computed inline into its consumer
  * \param self The schedule state
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 4d5ee8aa30..01d5e89c7a 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1724,7 +1724,17 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   Array<tir::StmtSRef> loops = tir::GetLoops(block_sref);
 
-  // Cond 1. The block is a reduction block and has trivial binding.
+  // Cond 1. The block has no annotations
+  if (!block->annotations.empty()) {
+    return false;
+  }
+
+  // Cond 2. The block has only one write buffer
+  if (block->writes.size() != 1) {
+    return false;
+  }
+
+  // Cond 3. The block is a reduction block and has trivial binding.
   const StmtSRef& scope_sref = GetScopeRoot(self, block_sref,                //
                                             /*require_stage_pipeline=*/false,  //
                                             /*require_subtree_compact_dataflow=*/false);
@@ -1733,7 +1743,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
     return false;
   }
 
-  // Cond 2. Every the loop axis must be either spatial axis or reduction axis.
+  // Cond 4. Every loop axis must be either a spatial axis or a reduction axis.
   for (const tir::StmtSRef& loop_sref : loops) {
     const tir::IterVarType& type = GetLoopIterType(loop_sref);
     if (type != tir::kDataPar && type != tir::kCommReduce) {
@@ -1741,16 +1751,16 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
     }
   }
 
-  // Cond 3. Whether there is at least one reduction loop.
-  // Cond 4. The loops are continuous, and the body of the innermost loop is exactly the block.
+  // Cond 5. Whether there is at least one reduction loop.
+  // Cond 6. The loops are continuous, and the body of the innermost loop is exactly the block.
   bool has_reduction_loop = false;
   for (size_t i = 0; i < loops.size(); ++i) {
-    // Cond 3.
+    // Cond 5.
     if (GetLoopIterType(loops[i]) == tir::kCommReduce) {
       has_reduction_loop = true;
     }
 
-    // Cond 4.
+    // Cond 6.
     const ForNode* loop_i = TVM_SREF_TO_FOR(loop_i, loops[i]);
     if (i < loops.size() - 1) {
       const ForNode* loop_i1 = TVM_SREF_TO_FOR(loop_i1, loops[i + 1]);
@@ -1768,14 +1778,14 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
     return false;
   }
 
-  // Cond 5. Can successfully calculating the cumulative loop length.
+  // Cond 7. The cumulative loop length can be calculated successfully.
   int64_t cum_space_len, cum_reduce_len;
   std::tie(cum_space_len, cum_reduce_len) = GetCumulativeSpaceAndReductionLength(self, block_sref);
   if (cum_space_len == -1 || cum_reduce_len == -1) {
     return false;
   }
 
-  // Cond 6.
+  // Cond 8.
   if (NeedsMultiLevelTiling(self, block_sref)) {
     // Do not use rfactor/cross-thread-reduction if we have enough parallelism on spatial loops.
    return !(cum_space_len >= cum_reduce_len || cum_space_len > max_parallel_extent);
@@ -1836,23 +1846,5 @@ bool HasIfThenElse(const Stmt& stmt) {
   return has_branch;
 }
 
-bool HasCacheWriteBlock(const Schedule& sch, const BlockRV& block_rv,
-                        const int& write_buffer_index) {
-  static tir::InstructionKind cache_write = tir::InstructionKind::Get("CacheWrite");
-  ICHECK(sch->trace().defined());
-  const Trace& trace = sch->trace().value();
-  for (const Instruction& inst : trace->insts) {
-    if (inst->kind.same_as(cache_write)) {
-      CHECK_EQ(inst->inputs.size(), 1);
-      const BlockRV& input_rv = Downcast<BlockRV>(inst->inputs[0]);
-      int buffer_index = Downcast<Integer>(inst->attrs[0])->value;
-      if (block_rv.same_as(input_rv) && buffer_index == write_buffer_index) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
index 6d102c0b48..300b5aeedf 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
@@ -17,12 +17,12 @@
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
-from tvm.meta_schedule.testing.schedule_rule import multi_level_tiling
+from tvm.meta_schedule.testing import te_workload
+from tvm.meta_schedule.testing.schedule_rule import add_rfactor
 from tvm.meta_schedule.testing.space_generation import check_trace
 from tvm.meta_schedule.tune_context import TuneContext
-from tvm.te import create_prim_func
-from tvm.meta_schedule.testing import te_workload
 from tvm.target import Target
+from tvm.te.operation import create_prim_func
 
 
 def _create_context(mod, target, rule) -> TuneContext:
@@ -41,228 +41,38 @@ def test_cpu_matmul():
     expected = [
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
-            "l2, l3, l4 = sch.get_loops(block=b0)",
-            "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])",
-            "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)",
-            "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])",
-            "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)",
-            "l23, l24 = sch.split(loop=l4, factors=[v21, v22])",
-            "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)",
-            "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)",
-        ],
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")',
-            "l2, l3, l4 = sch.get_loops(block=b0)",
-            "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)",
-            "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])",
-            "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)",
-            "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])",
-            "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)",
-            "l23, l24 = sch.split(loop=l4, factors=[v21, v22])",
-            "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)",
-            "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=True)",
preserve_unit_loops=True)", - ], + [], [ 'b0 = sch.get_block(name="C", func_name="main")', "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - ], - ] - target = Target("llvm") - ctx = _create_context( - create_prim_func( - te_workload.matmul( - n=512, - m=512, - k=512, - ) - ), - target=target, - rule=multi_level_tiling(target=target), - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 3 - check_trace(spaces, expected) - - -def test_cpu_matmul_relu(): - # pylint: disable=line-too-long - expected = [ - [ - 'b0 = sch.get_block(name="C", func_name="main")', - "b1, = sch.get_consumers(block=b0)", - "l2, l3, l4 = sch.get_loops(block=b0)", - "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", - "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", - "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", - "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", - "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", - "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", - "sch.reverse_compute_at(block=b1, loop=l18, preserve_unit_loops=True)", - ], - [ - 'b0 = sch.get_block(name="C", func_name="main")', - "b1, = sch.get_consumers(block=b0)", - "l2, l3, l4 = sch.get_loops(block=b0)", - "v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l9, l10, l11, l12 = sch.split(loop=l2, factors=[v5, v6, v7, v8])", - "v13, v14, v15, v16 = sch.sample_perfect_tile(loop=l3, n=4, max_innermost_factor=64)", - "l17, l18, l19, l20 = sch.split(loop=l3, factors=[v13, v14, v15, v16])", - "v21, v22 = sch.sample_perfect_tile(loop=l4, n=2, max_innermost_factor=64)", - "l23, l24 = sch.split(loop=l4, factors=[v21, v22])", - "sch.reorder(l9, l17, l10, l18, l23, l11, l19, l24, l12, l20)", - "sch.reverse_compute_at(block=b1, loop=l17, preserve_unit_loops=True)", + "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l6, l7 = sch.split(loop=l3, factors=[v4, v5])", + "b8 = sch.rfactor(loop=l7, factor_axis=2)", ], [ 'b0 = sch.get_block(name="C", func_name="main")', "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l6, l7 = sch.split(loop=l3, factors=[v4, v5])", + "b8 = sch.rfactor(loop=l6, factor_axis=2)", ], ] - # pylint: enable=line-too-long - target = Target("llvm") - ctx = 
-        create_prim_func(
-            te_workload.matmul_relu(
-                n=512,
-                m=512,
-                k=512,
-            )
-        ),
-        target=target,
-        rule=multi_level_tiling(target=target),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 3
-    check_trace(spaces, expected)
-
-
-def test_cuda_matmul():
-    # pylint: disable=line-too-long
-    expected = [
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
-            "l2, l3, l4 = sch.get_loops(block=b0)",
-            "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
-            "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])",
-            "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)",
-            "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])",
-            "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)",
-            "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])",
-            "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)",
-            "l31 = sch.fuse(l10, l20)",
-            'sch.bind(loop=l31, thread_axis="blockIdx.x")',
-            "l32 = sch.fuse(l11, l21)",
-            'sch.bind(loop=l32, thread_axis="vthread.x")',
-            "l33 = sch.fuse(l12, l22)",
-            'sch.bind(loop=l33, thread_axis="threadIdx.x")',
-            'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
-            "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)",
-            "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
-            "l41 = sch.fuse(l39, l40)",
-            "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)',
-            'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)",
-            "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)",
-            "l51 = sch.fuse(l49, l50)",
-            "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)',
-            "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)",
-        ]
-    ]
-    # pylint: enable=line-too-long
-    target = Target("cuda", host="llvm")
-    ctx = _create_context(
-        create_prim_func(
-            te_workload.matmul(
-                n=512,
-                m=512,
-                k=512,
-            )
-        ),
-        target=target,
-        rule=multi_level_tiling(target=target),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
-
-
-def test_cuda_matmul_relu():
-    # pylint: disable=line-too-long
-    expected = [
-        [
-            'b0 = sch.get_block(name="C", func_name="main")',
-            'b1 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")',
-            "l2, l3, l4 = sch.get_loops(block=b0)",
-            "v5, v6, v7, v8, v9 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)",
-            "l10, l11, l12, l13, l14 = sch.split(loop=l2, factors=[v5, v6, v7, v8, v9])",
-            "v15, v16, v17, v18, v19 = sch.sample_perfect_tile(loop=l3, n=5, max_innermost_factor=64)",
-            "l20, l21, l22, l23, l24 = sch.split(loop=l3, factors=[v15, v16, v17, v18, v19])",
-            "v25, v26, v27 = sch.sample_perfect_tile(loop=l4, n=3, max_innermost_factor=64)",
-            "l28, l29, l30 = sch.split(loop=l4, factors=[v25, v26, v27])",
-            "sch.reorder(l10, l20, l11, l21, l12, l22, l28, l29, l13, l23, l30, l14, l24)",
-            "l31 = sch.fuse(l10, l20)",
-            'sch.bind(loop=l31, thread_axis="blockIdx.x")',
-            "l32 = sch.fuse(l11, l21)",
-            'sch.bind(loop=l32, thread_axis="vthread.x")',
-            "l33 = sch.fuse(l12, l22)",
-            'sch.bind(loop=l33, thread_axis="threadIdx.x")',
-            'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")',
-            "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)",
-            "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
-            "l41 = sch.fuse(l39, l40)",
-            "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)',
-            'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)",
-            "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)",
-            "l51 = sch.fuse(l49, l50)",
-            "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)',
-            "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)",
-        ]
-    ]
-    # pylint: enable=line-too-long
-    target = Target("cuda", host="llvm")
+    target = Target("llvm --num-cores=32")
     ctx = _create_context(
         create_prim_func(
-            te_workload.matmul_relu(
-                n=512,
-                m=512,
+            te_workload.matmul(
+                n=4,
+                m=4,
                 k=512,
             )
         ),
         target=target,
-        rule=multi_level_tiling(target=target),
+        rule=add_rfactor(target=target),
     )
     spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
+    assert len(spaces) == 3
     check_trace(spaces, expected)
 
 
 if __name__ == "__main__":
     test_cpu_matmul()
-    test_cpu_matmul_relu()
-    test_cuda_matmul()
-    test_cuda_matmul_relu()
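
Usage sketch: the new rule can be driven standalone, mirroring the unit test above. This is a minimal sketch, assuming the TuneContext keyword names (`sch_rules`, `task_name`) and the space-generator initialization used by the test helper `_create_context` at this commit:

    # Illustration only; TuneContext keyword names are an assumption based on the
    # test helper `_create_context` from this era of the codebase.
    from tvm.meta_schedule.schedule_rule import AddRFactor
    from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
    from tvm.meta_schedule.testing import te_workload
    from tvm.meta_schedule.tune_context import TuneContext
    from tvm.target import Target
    from tvm.te.operation import create_prim_func

    # A reduction-heavy matmul: 4x4 spatial iterations against a 512-long
    # reduction axis, so there is too little spatial parallelism and rfactor pays off.
    mod = create_prim_func(te_workload.matmul(n=4, m=4, k=512))
    ctx = TuneContext(
        mod=mod,
        target=Target("llvm --num-cores=32"),
        space_generator=PostOrderApply(),
        sch_rules=[AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)],
        task_name="add-rfactor-demo",
    )
    ctx.space_generator.initialize_with_tune_context(ctx)
    spaces = ctx.space_generator.generate_design_space(mod=mod)
    # Three candidate schedules, matching the trace check in the test: rfactor on
    # one split of the fused reduction loop, rfactor on the other split, and the
    # untouched original schedule.
    assert len(spaces) == 3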