diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 95fce13df02f..b3a4f78385c1 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -162,6 +162,13 @@ class ScheduleRule : public runtime::ObjectRef {
    */
   TVM_DLL static ScheduleRule AddRFactor(int max_jobs_per_core,  //
                                          Optional<Integer> max_innermost_factor);
+  /*!
+   * \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks
+   * when needed
+   * \param thread_extents Candidates of thread axis extent (values are required to be positive).
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
   /*!
    * \brief A rule that randomly select a compute-at location for a free block
    * \return The rule created
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index 475c43a3fda1..c54eecf8b835 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -18,5 +18,6 @@
 """
 from .add_rfactor import AddRFactor
 from .auto_inline import AutoInline
+from .cross_thread_reduction import CrossThreadReduction
 from .schedule_rule import PyScheduleRule, ScheduleRule
 from .random_compute_location import RandomComputeLocation
diff --git a/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py
new file mode 100644
index 000000000000..f242e42aea4b
--- /dev/null
+++ b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Rules which apply cross-thread reduction to some reduction blocks when needed"""
+from typing import List
+
+from tvm._ffi import register_object
+
+from .. import _ffi_api
+from .schedule_rule import ScheduleRule
+
+
+@register_object("meta_schedule.CrossThreadReduction")
+class CrossThreadReduction(ScheduleRule):
+    """A schedule rule which applies cross-thread reduction to some reduction blocks
+    when needed
+
+    Parameters
+    ----------
+    thread_extents: List[int]
+        Candidates of thread axis extent (values are required to be positive).
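+
+    Examples
+    --------
+    A minimal construction sketch; the extents below simply mirror the power-of-two
+    candidates used by the testing helper in this patch:
+
+    .. code-block:: python
+
+        rule = CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])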
+ """ + + def __init__(self, thread_extents: List[int]) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleCrossThreadReduction, # type: ignore # pylint: disable=no-member + thread_extents, + ) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index 020869da4b10..b9606eed0eb4 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -18,6 +18,7 @@ from tvm.meta_schedule.schedule_rule import ( AddRFactor, AutoInline, + CrossThreadReduction, ScheduleRule, ) from tvm.target import Target @@ -53,3 +54,10 @@ def add_rfactor(target: Target) -> ScheduleRule: if target.kind.name == "llvm": return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64) raise NotImplementedError(f"{target.kind.name} is not supported") + + +def cross_thread_reduction(target: Target) -> ScheduleRule: + """Default schedule rules for with cross-thread reduction""" + if target.kind.name == "cuda": + return CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) + raise NotImplementedError(f"{target.kind.name} is not supported") diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc new file mode 100644 index 000000000000..0c8546ccfcdd --- /dev/null +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +class CrossThreadReductionNode : public ScheduleRuleNode { + public: + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + Target target = context->target.value(); + + Optional opt_max_threads_per_block = target->GetAttr("max_threads_per_block"); + Optional opt_warp_size = target->GetAttr("thread_warp_size"); + + if (!opt_max_threads_per_block.defined()) { + LOG(WARNING) << "Target does not have attribute \"max_threads_per_block\", therefore the " + "rule CrossThreadReduction will not be applied"; + } + if (!opt_warp_size.defined()) { + LOG(WARNING) << "Target does not have attribute \"thread_warp_size\", therefore the rule " + "CrossThreadReduction will not be applied"; + } + max_threads_per_block = opt_max_threads_per_block.value_or(Integer(-1))->value; + warp_size = opt_warp_size.value_or(Integer(-1))->value; + } + + // Inherited from ScheduleRuleNode + Array Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { + // Step 0. Check the conditions of this rule. 
+    if (max_threads_per_block == -1 || warp_size == -1) {
+      return {sch};
+    }
+    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
+    if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_threads_per_block,
+                                            warp_size)) {
+      return {sch};
+    }
+
+    // Step 1. Make a copy of the original schedule. The new copy is used for scheduling.
+    tir::Schedule tmp_sch = sch->Copy();
+    tmp_sch->Seed(sch->ForkSeed());
+
+    // Step 2. Check the opportunity for block fusion. We say a block is "fusible" if we can
+    // compute it at its consumers. We want to fuse as much as possible because it results in
+    // a significantly faster schedule.
+    bool fusible = false;
+    // `target_loop` is the loop at which the input block will be computed.
+    tir::LoopRV target_loop{nullptr};
+    // `target_block` is the consumer block that the input block will be computed at.
+    tir::BlockRV target_block{nullptr};
+    // `tgt_block_innermost_loop` is the innermost loop outside the target block.
+    tir::LoopRV tgt_block_innermost_loop{nullptr};
+
+    std::tie(fusible, target_loop, target_block, tgt_block_innermost_loop) =
+        GetComputeTargetLoopAndBlock(tmp_sch, block_rv);
+
+    // Step 3. Try block fusion.
+    int n_candidate = static_cast<int>(thread_extents.size());
+    Array<FloatImm> probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate));
+    tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs);
+    if (fusible) {
+      ICHECK(target_block.defined());
+      ICHECK(target_loop.defined());
+
+      // Step 3.1.
+      // - If the outer loops of `target_block` haven't been bound to "threadIdx.x", we should
+      //   first bind the innermost outer loop of `target_block` to threadIdx. Possibly we need to
+      //   split the loop before binding.
+      // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor.
+      if (!InThreadScope(tmp_sch, target_block)) {
+        const Array<tir::LoopRV>& split_res =
+            tmp_sch->Split(tgt_block_innermost_loop, {NullOpt, thread_extent});
+        tmp_sch->Bind(split_res[1], "threadIdx.x");
+        if (tgt_block_innermost_loop.same_as(target_loop)) {
+          target_loop = split_res[0];
+        }
+      } else {
+        thread_extent = GetThreadIdxExtentFromTrace(tmp_sch->trace().value());
+      }
+      // Step 3.2. Do the compute-at.
+      tmp_sch->ComputeAt(block_rv, target_loop, /*preserve_unit_loops=*/true);
+      // Step 3.3. Set the storage scope of the output buffer to shared memory.
+      tmp_sch->SetScope(block_rv, /*buffer_index=*/0, /*storage_scope=*/"shared");
+    }
+
+    // Step 4. Reorder the loop axes if reduction loops are not innermost. After the reordering,
+    // fuse all the reduction loops.
+    size_t num_spatial_loops;
+    tir::LoopRV fused_reduce_loop;
+    ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
+    // Step 5. Split the fused reduction loop and bind the inner one to threadIdx.
+    const Array<tir::LoopRV>& split_res =
+        tmp_sch->Split(fused_reduce_loop, {NullOpt, thread_extent});
+    tmp_sch->Bind(split_res[1], "threadIdx.x");
+
+    return {tmp_sch, sch};
+  }
+
+ private:
+  /*!
+   * \brief Check whether the input block is in thread scope, i.e., one of its outer loops is
+   * bound to threadIdx.
+   * \param sch The TensorIR schedule
+   * \param block The block to be checked
+   * \return A boolean indicating whether the block is in thread scope.
+   */
+  bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) {
+    const Array<tir::LoopRV>& axes = sch->GetLoops(block);
+    for (const tir::LoopRV& loop_rv : axes) {
+      const tir::For& loop = sch->Get(loop_rv);
+      runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get());
+      if (tir::IsThreadIdx(thread_scope)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Get the ExprRV which is used to define the extent of a given loop.
+   * \param trace The trace of the schedule, where the extent is to be found
+   * \param loop The loop whose extent is to be found
+   * \param extent The search result
+   * \return Whether the search is successful.
+   */
+  bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop,
+                             tir::ExprRV* extent) {
+    for (const tir::Instruction& inst : trace->insts) {
+      if (inst->kind->name == "Split") {
+        int i = std::find(inst->outputs.begin(), inst->outputs.end(), loop) - inst->outputs.begin();
+        CHECK(inst->inputs[1 + i].defined())
+            << "ValueError: Extracting an extent which needs inference is not supported so far";
+        *extent = Downcast<tir::ExprRV>(inst->inputs[1 + i]);
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Get the ExprRV extent of "threadIdx.x" in the given schedule trace.
+   * \param trace The trace of the schedule, where the extent is to be found
+   * \return The extent of "threadIdx.x" in the input schedule
+   */
+  tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) {
+    tir::ExprRV extent{nullptr};
+    for (const tir::Instruction& inst : trace->insts) {
+      if (inst->kind->name == "Bind" && Downcast<String>(inst->attrs[0]) == "threadIdx.x") {
+        if (GetLoopRVExtentSource(trace, Downcast<tir::LoopRV>(inst->inputs[0]), &extent)) {
+          return extent;
+        }
+      }
+    }
+    CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\"";
+    throw;
+  }
+
+  /*!
+   * \brief Get the compute-at target loop and the first block under the target loop.
+   * \param sch The TensorIR schedule
+   * \param block_rv The block whose compute-at target loop is queried
+   * \return A tuple consisting of
+   * 1. a boolean indicating whether the block can be computed at some target loop (a.k.a. fusible);
+   * 2. the compute-at target loop when fusible, or a null loop random variable;
+   * 3. the first block under the target loop when fusible, or a null block random variable;
+   * 4. the innermost loop outside the target block when fusible, or a null loop random variable.
+   */
+  std::tuple<bool, tir::LoopRV, tir::BlockRV, tir::LoopRV> GetComputeTargetLoopAndBlock(
+      const tir::Schedule& sch, const tir::BlockRV& block_rv) {
+    // Step 1. Get all the consumers of the input block.
+    Array<tir::BlockRV> consumers = sch->GetConsumers(block_rv);
+
+    // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is
+    // not fusible.
+    if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) {
+      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
+                             tir::LoopRV{nullptr});
+    }
+
+    // Step 3. Calculate the lowest common ancestor of all the consumers.
+    // - If the lowest common ancestor is a block:
+    //   - if there is only one consumer, the target block is that consumer;
+    //   - if there are multiple consumers, they must not share a common loop, and the case is not
+    //     fusible;
+    // - If the lowest common ancestor is a loop, the target block is also the first consumer.
+    const tir::StmtSRef& lca_sref =
+        tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers));
+    if (consumers.size() > 1 && lca_sref->StmtAs<tir::BlockNode>() != nullptr) {
+      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
+                             tir::LoopRV{nullptr});
+    }
+
+    // Step 4. Get the outer loops of the target block, and get the compute-at position index.
+    Array<tir::LoopRV> tgt_block_loops = sch->GetLoops(consumers[0]);
+    int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref);
+
+    // Step 5. A negative position index means not fusible, and vice versa.
+    if (pos < 0) {
+      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
+                             tir::LoopRV{nullptr});
+    } else {
+      return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back());
+    }
+  }
+
+  /*!
+   * \brief Get the compute-at position index of the input block, according to
+   * 1. the loops outside the input block;
+   * 2. the loops outside the target block;
+   * 3. the lowest common ancestor of all the consumers of the input block.
+   * \param sch The TensorIR schedule
+   * \param block_loops The loops outside the input block
+   * \param tgt_block_loops The loops outside the target block
+   * \param lca_sref The lowest common ancestor of all the consumers of the input block
+   * \return The compute-at position index of the input block
+   */
+  int GetComputePosition(const tir::Schedule& sch, const Array<tir::LoopRV>& block_loops,
+                         const Array<tir::LoopRV>& tgt_block_loops, const tir::StmtSRef& lca_sref) {
+    int n_block_loop = static_cast<int>(block_loops.size());
+    int n_tgt_block_loop = static_cast<int>(tgt_block_loops.size());
+
+    for (int i = 0; i < n_block_loop && i < n_tgt_block_loop; ++i) {
+      if (tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) {
+        return i - 1;
+      } else if (sch->GetSRef(tgt_block_loops[i]).same_as(lca_sref)) {
+        // If the lowest common ancestor is a loop, the compute location of the input block should
+        // not be deeper than the LCA loop.
+        return i;
+      }
+    }
+    return std::min(n_block_loop, n_tgt_block_loop) - 1;
+  }
+
+ public:
+  /*! \brief The maximum number of threads allowed in a thread block */
+  int max_threads_per_block;
+  /*! \brief The number of threads per warp */
+  int warp_size;
+  /*! \brief Candidates of thread axis extent (values are required to be positive). */
+  Array<Integer> thread_extents;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("max_threads_per_block", &max_threads_per_block);
+    v->Visit("warp_size", &warp_size);
+    v->Visit("thread_extents", &thread_extents);
+  }
+
+  static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction";
+  TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode);
+};
+
+ScheduleRule ScheduleRule::CrossThreadReduction(Array<Integer> thread_extents) {
+  for (const Integer& extent : thread_extents) {
+    CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive";
+  }
+  ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>();
+  n->thread_extents = std::move(thread_extents);
+  return ScheduleRule(n);
+}
+
+TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction")
+    .set_body_typed(ScheduleRule::CrossThreadReduction);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 2053f8ddde93..be5e55d4ec70 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1788,8 +1788,9 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self,  //
   const StmtSRef& scope_sref = GetScopeRoot(self, block_sref,                  //
                                             /*require_stage_pipeline=*/false,  //
                                             /*require_subtree_compact_dataflow=*/false);
-  if (!(IsReductionBlock(self, block_sref, scope_sref) &&  //
-        IsTrivialBinding(self, block_sref))) {
+  if (!IsReductionBlock(self, block_sref, scope_sref)  //
+      || !IsTrivialBinding(self, block_sref)           //
+      || HasBeenMultiLevelTiled(block_sref)) {
     return false;
   }
 
diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc
index dc05f10cc4f8..1021cbf2c625 100644
--- a/src/tir/schedule/trace.cc
+++ b/src/tir/schedule/trace.cc
@@ -448,8 +448,10 @@ Trace TraceNode::Simplified(bool remove_postproc) const {
     }
     // Add its inputs as "used" ones
     for (const ObjectRef& obj : inst->inputs) {
-      if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
-          obj->IsInstance<VarNode>()) {
+      if (!obj.defined()) {
+        continue;
+      } else if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
+                 obj->IsInstance<VarNode>()) {
         used_rvs.insert(obj.get());
         continue;
       } else if (obj->IsInstance<PrimExprNode>()) {
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index ebd2284cbe3c..bb34c6aadaba 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -120,6 +120,21 @@ inline Array<For> LoopSRefs2Loops(const Array<StmtSRef>& loop_srefs) {
   return loops;
 }
 
+/*!
+ * \brief Convert an array of block rvs to an array of block StmtSRefs
+ * \param sch The schedule used to evaluate the random variables
+ * \param block_rvs The random variables to be converted
+ * \return The conversion result srefs
+ */
+inline Array<StmtSRef> BlockRVs2StmtSRefs(const Schedule& sch, const Array<BlockRV>& block_rvs) {
+  Array<StmtSRef> block_srefs;
+  block_srefs.reserve(block_rvs.size());
+  for (const BlockRV& block_rv : block_rvs) {
+    block_srefs.push_back(sch->GetSRef(block_rv));
+  }
+  return block_srefs;
+}
+
 /******** Storage scope ********/
 
 /*!
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
new file mode 100644
index 000000000000..7bed18b0f9ea
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py
@@ -0,0 +1,241 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+
+from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
+from tvm.meta_schedule.testing import te_workload
+from tvm.meta_schedule.testing.schedule_rule import cross_thread_reduction
+from tvm.meta_schedule.testing.space_generation import check_trace
+from tvm.meta_schedule.tune_context import TuneContext
+from tvm.target import Target
+from tvm.te.operation import create_prim_func
+
+import tvm
+from tvm.script import tir as T
+
+
+@tvm.script.ir_module
+class Softmax_mn_after_inline:
+    @T.prim_func
+    def main(
+        A: T.Buffer[(256, 256), "float32"], T_softmax_norm: T.Buffer[(256, 256), "float32"]
+    ) -> None:
+        T_softmax_maxelem = T.alloc_buffer([256], dtype="float32")
+        T_softmax_expsum = T.alloc_buffer([256], dtype="float32")
+        for i0, i1 in T.grid(256, 256):
+            with T.block("T_softmax_maxelem"):
+                i0_1, k = T.axis.remap("SR", [i0, i1])
+                with T.init():
+                    T_softmax_maxelem[i0_1] = T.min_value("float32")
+                T_softmax_maxelem[i0_1] = T.max(T_softmax_maxelem[i0_1], A[i0_1, k])
+        for i0, i1 in T.grid(256, 256):
+            with T.block("T_softmax_expsum"):
+                i0_2, k = T.axis.remap("SR", [i0, i1])
+                with T.init():
+                    T_softmax_expsum[i0_2] = T.float32(0)
+                T_softmax_expsum[i0_2] = T_softmax_expsum[i0_2] + T.exp(
+                    A[i0_2, k] - T_softmax_maxelem[i0_2], dtype="float32"
+                )
+        for i0_3, i1 in T.grid(256, 256):
+            with T.block("T_softmax_norm"):
+                i0_4, i1_1 = T.axis.remap("SS", [i0_3, i1])
+                T.block_attr({"axis": 1})
+                T_softmax_norm[i0_4, i1_1] = (
+                    T.exp(A[i0_4, i1_1] - T_softmax_maxelem[i0_4], dtype="float32")
+                    / T_softmax_expsum[i0_4]
+                )
+
+
+def _create_context(mod, target, rule) -> TuneContext:
+    ctx = TuneContext(
+        mod=mod,
+        target=target,
+        space_generator=PostOrderApply(),
+        sch_rules=[rule],
+        task_name="test",
+    )
+    ctx.space_generator.initialize_with_tune_context(ctx)
+    for sch_rule in ctx.sch_rules:
+        sch_rule.initialize_with_tune_context(ctx)
+    return ctx
+
+
+def test_gpu_softmax_mn():
+    expected = [
+        [],
+        [
+            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
+            "b1, = sch.get_consumers(block=b0)",
+            "l2, l3 = sch.get_loops(block=b1)",
+            "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
+            "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+            'sch.bind(loop=l6, thread_axis="threadIdx.x")',
+            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)",
+            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
+            "l7, l8, l9 = sch.get_loops(block=b0)",
+            "l10, l11 = sch.split(loop=l9, factors=[None, v4])",
+            'sch.bind(loop=l11, thread_axis="threadIdx.x")',
+        ],
+        [
+            'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")',
+            "b1, = sch.get_consumers(block=b0)",
+            "l2, l3 = sch.get_loops(block=b1)",
+            "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
+            "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+            'sch.bind(loop=l6, thread_axis="threadIdx.x")',
+            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)",
+            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
+            "l7, l8, l9 = sch.get_loops(block=b0)",
+            "l10, l11 = sch.split(loop=l9, factors=[None, v4])",
+            'sch.bind(loop=l11, thread_axis="threadIdx.x")',
+        ],
+        [
+            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
+            'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")',
+            "b2, = sch.get_consumers(block=b1)",
+            "l3, l4 = sch.get_loops(block=b2)",
+            "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
+            "l6, l7 = sch.split(loop=l4, factors=[None, v5])",
+            'sch.bind(loop=l7, thread_axis="threadIdx.x")',
+            "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1)",
+            'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
+            "l8, l9, l10 = sch.get_loops(block=b1)",
+            "l11, l12 = sch.split(loop=l10, factors=[None, v5])",
+            'sch.bind(loop=l12, thread_axis="threadIdx.x")',
+            "b13, = sch.get_consumers(block=b0)",
+            "l14, l15 = sch.get_loops(block=b13)",
+            "v16 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
+            "l17, l18 = sch.split(loop=l15, factors=[None, v16])",
+            'sch.bind(loop=l18, thread_axis="threadIdx.x")',
+            "sch.compute_at(block=b0, loop=l14, preserve_unit_loops=1)",
+            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
+            "l19, l20, l21 = sch.get_loops(block=b0)",
+            "l22, l23 = sch.split(loop=l21, factors=[None, v16])",
+            'sch.bind(loop=l23, thread_axis="threadIdx.x")',
+        ],
+    ]
+    target = Target("nvidia/geforce-rtx-3090", host="llvm")
+    ctx = _create_context(
+        create_prim_func(
+            te_workload.softmax_mn(
+                n=256,
+                m=256,
+            )
+        ),
+        target=target,
+        rule=cross_thread_reduction(target=target),
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 4
+    check_trace(spaces, expected)
+
+
+def test_gpu_softmax_mn_after_inline():
+    expected = [
+        [],
+        [
+            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
+            "v1 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
+            "l2, l3 = sch.get_loops(block=b0)",
+            "l4, l5 = sch.split(loop=l3, factors=[None, v1])",
+            'sch.bind(loop=l5, thread_axis="threadIdx.x")',
+        ],
+        [
+            'b0 = sch.get_block(name="T_softmax_expsum", func_name="main")',
+            "b1, = sch.get_consumers(block=b0)",
+            "l2, l3 = sch.get_loops(block=b1)",
+            "v4 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
+            "l5, l6 = sch.split(loop=l3, factors=[None, v4])",
+            'sch.bind(loop=l6, thread_axis="threadIdx.x")',
+            "sch.compute_at(block=b0, loop=l2, preserve_unit_loops=1)",
+            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
+            "l7, l8, l9 = sch.get_loops(block=b0)",
+            "l10, l11 = sch.split(loop=l9, factors=[None, v4])",
+            'sch.bind(loop=l11, thread_axis="threadIdx.x")',
+        ],
+        [
+            'b0 = sch.get_block(name="T_softmax_maxelem", func_name="main")',
+            'b1 = sch.get_block(name="T_softmax_expsum", func_name="main")',
+            "b2, = sch.get_consumers(block=b1)",
+            "l3, l4 = sch.get_loops(block=b2)",
+            "v5 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
+            "l6, l7 = sch.split(loop=l4, factors=[None, v5])",
+            'sch.bind(loop=l7, thread_axis="threadIdx.x")',
+            "sch.compute_at(block=b1, loop=l3, preserve_unit_loops=1)",
+            'sch.set_scope(block=b1, buffer_index=0, storage_scope="shared")',
+            "l8, l9, l10 = sch.get_loops(block=b1)",
+            "l11, l12 = sch.split(loop=l10, factors=[None, v5])",
+            'sch.bind(loop=l12, thread_axis="threadIdx.x")',
+            "b13, b14 = sch.get_consumers(block=b0)",
+            "l15, l16, l17, l18 = sch.get_loops(block=b13)",
+            "sch.compute_at(block=b0, loop=l15, preserve_unit_loops=1)",
+            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
+            "l19, l20, l21 = sch.get_loops(block=b0)",
+            "l22, l23 = sch.split(loop=l21, factors=[None, v5])",
+            'sch.bind(loop=l23, thread_axis="threadIdx.x")',
+        ],
+    ]
+    target = Target("nvidia/geforce-rtx-3090", host="llvm")
+    ctx = _create_context(
+        mod=Softmax_mn_after_inline,
+        target=target,
+        rule=cross_thread_reduction(target=target),
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 4
+    check_trace(spaces, expected)
+
+
+def test_gpu_batch_norm_bmn():
+    expected = [
+        [],
+        [
+            'b0 = sch.get_block(name="C", func_name="main")',
+            "b1, = sch.get_consumers(block=b0)",
+            "l2, = sch.get_loops(block=b1)",
+            "v3 = sch.sample_categorical(candidates=[4, 8, 16, 32, 64, 128, 256, 512], probs=[0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125])",
+            "l4, l5 = sch.split(loop=l2, factors=[None, v3])",
+            'sch.bind(loop=l5, thread_axis="threadIdx.x")',
+            "sch.compute_at(block=b0, loop=l4, preserve_unit_loops=1)",
+            'sch.set_scope(block=b0, buffer_index=0, storage_scope="shared")',
+            "l6, l7, l8, l9 = sch.get_loops(block=b0)",
+            "l10 = sch.fuse(l8, l9)",
+            "l11, l12 = sch.split(loop=l10, factors=[None, v3])",
+            'sch.bind(loop=l12, thread_axis="threadIdx.x")',
+        ],
+    ]
+    target = Target("nvidia/geforce-rtx-3090", host="llvm")
+    ctx = _create_context(
+        create_prim_func(
+            te_workload.norm_bmn(
+                B=1,
+                M=512,
+                N=512,
+            )
+        ),
+        target=target,
+        rule=cross_thread_reduction(target=target),
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 2
+    check_trace(spaces, expected)
+
+
+if __name__ == "__main__":
+    test_gpu_softmax_mn()
+    test_gpu_softmax_mn_after_inline()
+    test_gpu_batch_norm_bmn()