forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MetaSchedule] Schedule Rule: Cross Thread Reduction (apache#9994)
Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Xiyou Zhou <xiyou@octoml.ai> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Wuwei Lin <wuwei@apache.org>
- Loading branch information
Showing
9 changed files
with
605 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 41 additions & 0 deletions
41
python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Rules which apply cross-thread reduction to some reduction blocks correspondingly when needed""" | ||
from typing import List | ||
|
||
from tvm._ffi import register_object | ||
|
||
from .. import _ffi_api | ||
from .schedule_rule import ScheduleRule | ||
|
||
|
||
@register_object("meta_schedule.CrossThreadReduction")
class CrossThreadReduction(ScheduleRule):
    """Schedule rule that applies cross-thread reduction to reduction blocks
    when it is needed.

    Parameters
    ----------
    thread_extents: List[int]
        Candidate extents for the thread axis. Every value must be positive.
    """

    def __init__(self, thread_extents: List[int]) -> None:
        # Resolve the FFI constructor first, then build the underlying handle from it.
        ffi_constructor = (
            _ffi_api.ScheduleRuleCrossThreadReduction  # type: ignore # pylint: disable=no-member
        )
        self.__init_handle_by_constructor__(ffi_constructor, thread_extents)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
285 changes: 285 additions & 0 deletions
285
src/meta_schedule/schedule_rule/cross_thread_reduction.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,285 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
#include "../utils.h" | ||
|
||
namespace tvm { | ||
namespace meta_schedule { | ||
|
||
class CrossThreadReductionNode : public ScheduleRuleNode { | ||
public: | ||
// Inherited from ScheduleRuleNode | ||
void InitializeWithTuneContext(const TuneContext& context) final { | ||
ICHECK(context->target.defined()); | ||
Target target = context->target.value(); | ||
|
||
Optional<Integer> opt_max_threads_per_block = target->GetAttr<Integer>("max_threads_per_block"); | ||
Optional<Integer> opt_warp_size = target->GetAttr<Integer>("thread_warp_size"); | ||
|
||
if (!opt_max_threads_per_block.defined()) { | ||
LOG(WARNING) << "Target does not have attribute \"max_threads_per_block\", therefore the " | ||
"rule CrossThreadReduction will not be applied"; | ||
} | ||
if (!opt_warp_size.defined()) { | ||
LOG(WARNING) << "Target does not have attribute \"thread_warp_size\", therefore the rule " | ||
"CrossThreadReduction will not be applied"; | ||
} | ||
max_threads_per_block = opt_max_threads_per_block.value_or(Integer(-1))->value; | ||
warp_size = opt_warp_size.value_or(Integer(-1))->value; | ||
} | ||
|
||
// Inherited from ScheduleRuleNode
/*!
 * \brief Try to build a cross-thread-reduction variant of the schedule for one block.
 * \param sch The input schedule; scheduling primitives are applied to a copy of it
 * \param block_rv The block to which the rule is applied
 * \return `{sch}` when the rule is not applicable, otherwise `{scheduled copy, sch}`
 */
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
  // Phase 0. Bail out when the target attributes are unknown or the block does not qualify.
  const bool target_info_missing = max_threads_per_block == -1 || warp_size == -1;
  if (target_info_missing) {
    return {sch};
  }
  tir::StmtSRef block_sref = sch->GetSRef(block_rv);
  const bool rule_applicable = NeedsRFactorOrCrossThreadReduction(
      sch->state(), block_sref, max_threads_per_block, warp_size);
  if (!rule_applicable) {
    return {sch};
  }

  // Phase 1. All scheduling below happens on a copy seeded from the original schedule.
  tir::Schedule work_sch = sch->Copy();
  work_sch->Seed(sch->ForkSeed());

  // Phase 2. Look for a fusion opportunity. The block is "fusible" if it can be computed at a
  // loop of one of its consumers; fusing as much as possible gives significantly faster
  // schedules.
  bool fusible = false;
  // The loop at which the input block will be computed.
  tir::LoopRV target_loop{nullptr};
  // The consumer block that the input block is computed at.
  tir::BlockRV target_block{nullptr};
  // The innermost loop outside the target block.
  tir::LoopRV tgt_block_innermost_loop{nullptr};
  std::tie(fusible, target_loop, target_block, tgt_block_innermost_loop) =
      GetComputeTargetLoopAndBlock(work_sch, block_rv);

  // Phase 3. Sample a thread extent uniformly from the candidates, then attempt the fusion.
  int num_candidates = static_cast<int>(thread_extents.size());
  Array<FloatImm> uniform_probs(num_candidates,
                                FloatImm(DataType::Float(64), 1.0 / num_candidates));
  tir::ExprRV thread_extent = work_sch->SampleCategorical(thread_extents, uniform_probs);
  if (fusible) {
    ICHECK(target_block.defined());
    ICHECK(target_loop.defined());

    // Phase 3.1.
    // - If an outer loop of `target_block` is already bound to threadIdx, reuse the extent of
    //   "threadIdx.x" recorded in the trace as the split factor.
    // - Otherwise split the innermost outer loop of `target_block` by the sampled extent and
    //   bind the inner part to "threadIdx.x".
    if (InThreadScope(work_sch, target_block)) {
      thread_extent = GetThreadIdxExtentFromTrace(work_sch->trace().value());
    } else {
      Array<tir::LoopRV> tgt_parts =
          work_sch->Split(tgt_block_innermost_loop, {NullOpt, thread_extent});
      work_sch->Bind(tgt_parts[1], "threadIdx.x");
      // The compute-at target was the loop just split: continue with its outer part.
      if (tgt_block_innermost_loop.same_as(target_loop)) {
        target_loop = tgt_parts[0];
      }
    }
    // Phase 3.2. Compute the input block at the target loop.
    work_sch->ComputeAt(block_rv, target_loop, /*preserve_unit_loops=*/true);
    // Phase 3.3. Put the output buffer of the block into shared memory.
    work_sch->SetScope(block_rv, /*buffer_index=*/0, /*storage_scope=*/"shared");
  }

  // Phase 4. Move the reduction loops innermost if needed and fuse them into a single loop.
  size_t num_spatial_loops;
  tir::LoopRV fused_reduce_loop;
  ReorderAndFuseReductionLoops(work_sch, block_rv, &fused_reduce_loop, &num_spatial_loops);

  // Phase 5. Split the fused reduction loop and bind its inner part to threadIdx.
  Array<tir::LoopRV> reduce_parts = work_sch->Split(fused_reduce_loop, {NullOpt, thread_extent});
  work_sch->Bind(reduce_parts[1], "threadIdx.x");

  return {work_sch, sch};
}
|
||
private: | ||
/*! | ||
* \brief Check whether the input block is in thread scope, i.e., some of its outer loop is | ||
* bound to threadIdx. | ||
* \param sch The TensorIR schedule | ||
* \param block The block to be checked | ||
* \return A boolean indicating whether the block is in thread scope. | ||
*/ | ||
bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) { | ||
const Array<tir::LoopRV>& axes = sch->GetLoops(block); | ||
for (const tir::LoopRV& loop_rv : axes) { | ||
const tir::For& loop = sch->Get(loop_rv); | ||
runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get()); | ||
if (tir::IsThreadIdx(thread_scope)) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
/*! | ||
* \brief Get the ExprRV which used to define the extent of a given loop. | ||
* \param trace The trace of the schedule, where the extent is to be found | ||
* \param loop The loop whose extent is to be found | ||
* \param extent The finding result | ||
* \return Whether the find is successful. | ||
*/ | ||
bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop, | ||
tir::ExprRV* extent) { | ||
for (const tir::Instruction& inst : trace->insts) { | ||
if (inst->kind->name == "Split") { | ||
int i = std::find(inst->outputs.begin(), inst->outputs.end(), loop) - inst->outputs.begin(); | ||
CHECK(inst->inputs[1 + i].defined()) | ||
<< "ValueError: Extracting an extent which needs inference is not supported so far"; | ||
*extent = Downcast<tir::ExprRV>(inst->inputs[1 + i]); | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
/*! | ||
* \brief Get the ExprRV extent of "threadIdx.x" in the given schedule trace. | ||
* \param trace The trace of the schedule, where the extent is to be found | ||
* \return The extent of "threadIdx.x" in the input schedule | ||
*/ | ||
tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) { | ||
tir::ExprRV extent{nullptr}; | ||
for (const tir::Instruction& inst : trace->insts) { | ||
if (inst->kind->name == "Bind" && Downcast<String>(inst->attrs[0]) == "threadIdx.x") { | ||
if (GetLoopRVExtentSource(trace, Downcast<tir::LoopRV>(inst->inputs[0]), &extent)) { | ||
return extent; | ||
} | ||
} | ||
} | ||
CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\""; | ||
throw; | ||
} | ||
|
||
/*! | ||
* \brief Get the compute-at target loop and the first block under the target loop. | ||
* \param sch The TensorIR schedule | ||
* \param block_rv The block whose compute-at target loop is queried | ||
* \return A tuple consisting of | ||
* 1. a boolean indicating whether the block can be computed at some target loop (a.k.a. fusible); | ||
* 2. the compute-at target loop when fusible, or a null loop random variable; | ||
* 3. the first block under the target loop when fusible, or a null block random variable; | ||
* 4. the innermost loop outside the target block when fusible, or a null block random variable. | ||
*/ | ||
std::tuple<bool, tir::LoopRV, tir::BlockRV, tir::LoopRV> GetComputeTargetLoopAndBlock( | ||
const tir::Schedule& sch, const tir::BlockRV& block_rv) { | ||
// Step 1. Get all the consumers of the input block. | ||
Array<tir::BlockRV> consumers = sch->GetConsumers(block_rv); | ||
|
||
// Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is | ||
// not fusible. | ||
if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) { | ||
return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, | ||
tir::LoopRV{nullptr}); | ||
} | ||
|
||
// Step 3. Calculate the lowest common ancestor of all the consumers. | ||
// - If the lowest common ancestor is a block: | ||
// - if there is only one consumer, the target block is that consumer; | ||
// - if there are multiple consumers, they must not share a common loop, and the case is not | ||
// fusible; | ||
// - If the lowest common ancestor is a loop, the target block is also the first consumer. | ||
const tir::StmtSRef& lca_sref = | ||
tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers)); | ||
if (consumers.size() > 1 && lca_sref->StmtAs<tir::BlockNode>() != nullptr) { | ||
return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, | ||
tir::LoopRV{nullptr}); | ||
} | ||
|
||
// Step 4. Get the outer loops of the target block, and get the compute-at position index. | ||
Array<tir::LoopRV> tgt_block_loops = sch->GetLoops(consumers[0]); | ||
int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref); | ||
|
||
// Step 5. A negative position index means not fusible, and vice-versa. | ||
if (pos < 0) { | ||
return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, | ||
tir::LoopRV{nullptr}); | ||
} else { | ||
return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back()); | ||
} | ||
} | ||
|
||
/*! | ||
* \brief Get the compute-at position index of the input block, according to | ||
* 1. the loops outside the input block; | ||
* 2. the loops outside the target block; | ||
* 3. the lowest common ancestor of all the consumers of the input block. | ||
* \param sch The TensorIR schedule | ||
* \param block_loops The loops outside the input block | ||
* \param tgt_block_loops The loops outside the target block | ||
* \param lca_sref The lowest common ancestor of all the consumers of the input block | ||
* \return The compute-at position index of the input block | ||
*/ | ||
int GetComputePosition(const tir::Schedule& sch, const Array<tir::LoopRV>& block_loops, | ||
const Array<tir::LoopRV>& tgt_block_loops, const tir::StmtSRef& lca_sref) { | ||
int n_block_loop = static_cast<int>(block_loops.size()); | ||
int n_tgt_block_loop = static_cast<int>(tgt_block_loops.size()); | ||
|
||
for (int i = 0; i < n_block_loop && i < n_tgt_block_loop; ++i) { | ||
if (tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) { | ||
return i - 1; | ||
} else if (sch->GetSRef(tgt_block_loops[i]).same_as(lca_sref)) { | ||
// If the lowest common ancestor is a loop, the compute location of the input block should | ||
// not be deeper than the LCA loop. | ||
return i; | ||
} | ||
} | ||
return std::min(n_block_loop, n_tgt_block_loop) - 1; | ||
} | ||
|
||
public: | ||
/*! \brief The maximum number of threads allowed in a thread block */ | ||
int max_threads_per_block; | ||
/*! \brief The number of threads per warp */ | ||
int warp_size; | ||
/*! \brief Candidates of thread axis extent (values are required to be positive). */ | ||
Array<Integer> thread_extents; | ||
|
||
void VisitAttrs(tvm::AttrVisitor* v) { | ||
v->Visit("max_threads_per_block", &max_threads_per_block); | ||
v->Visit("warp_size", &warp_size); | ||
v->Visit("thread_extents", &thread_extents); | ||
} | ||
|
||
static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode); | ||
}; | ||
|
||
/*!
 * \brief Create a CrossThreadReduction schedule rule.
 * \param thread_extents Candidates of thread axis extent; every value must be positive
 * \return The schedule rule holding the given candidates
 */
ScheduleRule ScheduleRule::CrossThreadReduction(Array<Integer> thread_extents) {
  // Validate every candidate before any node is allocated.
  for (const Integer& extent : thread_extents) {
    CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive";
  }

  auto node = make_object<CrossThreadReductionNode>();
  node->thread_extents = std::move(thread_extents);
  return ScheduleRule(std::move(node));
}
|
||
// Register the node type of the rule.
TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode);

// Expose the factory under the global name "meta_schedule.ScheduleRuleCrossThreadReduction".
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction")
    .set_body_typed(ScheduleRule::CrossThreadReduction);
|
||
} // namespace meta_schedule | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.