[MetaSchedule] Performance Alignment - NRM and SFM (CUDA) #559

3 changes: 2 additions & 1 deletion include/tvm/meta_schedule/schedule_rule.h
@@ -167,9 +167,10 @@ class ScheduleRule : public runtime::ObjectRef {
/*!
* \brief Create a schedule rule that applies cross-thread reduction to reduction blocks where
* applicable
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule CrossThreadReduction();
TVM_DLL static ScheduleRule CrossThreadReduction(Optional<Integer> max_innermost_factor);
/*!
* \brief A rule that randomly selects a compute-at location for a free block
* \return The rule created
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/schedule_rule/add_rfactor.py
@@ -34,7 +34,7 @@ class AddRFactor(ScheduleRule):
parallelism, i.e. `num_cores * max_jobs_per_core`.
Use -1 to disable parallelism.
max_innermost_factor: Optional[int] = None
The maximum size of the innermost factor. NullOpt means no limit.
The maximum size of the innermost factor. None means no limit.
"""

def __init__(
@@ -27,9 +27,15 @@
class CrossThreadReduction(ScheduleRule):
"""A schedule rule which applies cross-thread reduction to some reduction blocks
correspondingly when needed

Parameters
----------
max_innermost_factor: Optional[int] = None
The maximum size of the innermost factor. None means no limit.
"""

def __init__(self) -> None:
def __init__(self, max_innermost_factor: Optional[int] = None) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleCrossThreadReduction, # type: ignore # pylint: disable=no-member
max_innermost_factor,
)
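For orientation, a minimal sketch of constructing the extended rule from Python; the cap of 64 mirrors the defaults used elsewhere in this PR:

```python
from tvm.meta_schedule.schedule_rule import CrossThreadReduction

# A sketch: cap the sampled innermost split factor (the candidate
# "threadIdx.x" extent) at 64; passing None keeps the unbounded behavior.
rule = CrossThreadReduction(max_innermost_factor=64)
```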
3 changes: 2 additions & 1 deletion python/tvm/meta_schedule/testing/schedule_rule.py
@@ -42,6 +42,7 @@ def get(target: Target) -> List[ScheduleRule]:
if target.kind.name == "cuda":
return [
auto_inline(target),
cross_thread_reduction(target),
multi_level_tiling(target),
auto_inline_after_tiling(target),
parallel_vectorize_unroll(target),
@@ -197,5 +198,5 @@ def add_rfactor(target: Target) -> ScheduleRule:
def cross_thread_reduction(target: Target) -> ScheduleRule:
"""Default schedule rules for with cross-thread reduction"""
if target.kind.name == "cuda":
return CrossThreadReduction()
return CrossThreadReduction(max_innermost_factor=64)
raise NotImplementedError(f"{target.kind.name} is not supported")
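A quick usage sketch of the dispatcher above; non-CUDA targets still fall through to `NotImplementedError`:

```python
from tvm.target import Target
from tvm.meta_schedule.testing.schedule_rule import cross_thread_reduction

# Returns CrossThreadReduction(max_innermost_factor=64) for any CUDA target.
rule = cross_thread_reduction(Target("cuda"))
```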
9 changes: 9 additions & 0 deletions python/tvm/meta_schedule/testing/te_workload.py
@@ -845,6 +845,15 @@ def create_te_workload(name: str, idx: int) -> tir.PrimFunc:
(1, 4096, 1024),
],
),
"SFM": (
softmax_mn,
[
(256, 256),
(512, 512),
(1024, 1024),
(2048, 2048),
],
),
"C2d-BN-RELU": (
conv2d_nhwc_bn_relu,
[
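With the "SFM" entry registered, the softmax workload can be pulled out by name; a usage sketch, assuming the first registered shape (256, 256):

```python
from tvm.meta_schedule.testing.te_workload import create_te_workload

# Fetch the softmax PrimFunc instantiated at M = N = 256.
func = create_te_workload("SFM", 0)
print(func.script())
```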
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/tune.py
@@ -158,6 +158,7 @@ def _sch_rules() -> List[ScheduleRule]:
require_ordered=False,
disallow_op=None,
),
M.CrossThreadReduction(max_innermost_factor=64),
M.MultiLevelTiling(
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
2 changes: 1 addition & 1 deletion src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -1274,7 +1274,7 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void {
const auto& candidate = candidates[task_id];
std::vector<std::vector<double>> features;
ExtractSingle(candidate->sch->mod(), is_gpu, &features);
ExtractSingle(DeepCopyIRModule(candidate->sch->mod()), is_gpu, &features);
results[task_id] = tir::utils::AsNDArray(features);
};
support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
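The deep copy matters because `f` runs concurrently on `tune_context->num_threads` workers: feature extraction transforms the module, so sharing one IRModule across workers could race. A generic Python sketch of the same copy-per-worker pattern (illustrative only, not the TVM API):

```python
import copy
from concurrent.futures import ThreadPoolExecutor

def extract_features(shared_state, task_id):
    local = copy.deepcopy(shared_state)  # each worker mutates a private copy
    local["transformed"] = True          # stand-in for the in-place transformation
    return task_id

state = {"transformed": False}
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(lambda i: extract_features(state, i), range(8)))
```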
12 changes: 9 additions & 3 deletions src/meta_schedule/postproc/rewrite_unbound_block.cc
@@ -46,6 +46,7 @@ BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) {
int i_block_idx = -1;
int i_thread_idx = -1;
int i_multi_child = -1;
int i_spatial_loop = -1;
for (int i = 0; i < n; ++i) {
const StmtSRef& loop_sref = loops[i];
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
@@ -65,20 +66,25 @@
i_multi_child = i + 1;
}
}
if (tir::GetLoopIterType(loop_sref) == IterVarType::kDataPar) {
if (i_spatial_loop == i - 1) {
++i_spatial_loop;
}
}
}
if (i_multi_child == -1) {
i_multi_child = n;
}
if (i_block_idx != -1 && i_thread_idx != -1) {
if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) {
return BindType::kNoBind;
} else if (i_block_idx != -1 && i_thread_idx == -1) {
ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not";
throw;
} else if (i_block_idx == -1 && i_thread_idx != -1) {
*fuse_first_num = std::min(i_multi_child, i_thread_idx);
*fuse_first_num = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
return BindType::kBindBlock;
} else { // i_block_idx == -1 && i_thread_idx == -1
*fuse_first_num = i_multi_child;
*fuse_first_num = std::min(i_multi_child, i_spatial_loop + 1);
return BindType::kBindBlockThread;
}
}
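To make the new bookkeeping concrete, here is a Python paraphrase of the visible decision logic (inputs are simplified, and the condition guarding `i_multi_child` is collapsed in the diff, so a stand-in is used): `i_spatial_loop` tracks only the contiguous prefix of data-parallel loops, so the fused-and-bound region can no longer swallow a reduction loop.

```python
def get_bind_type(loops):
    # Each loop: {"bind": "blockIdx" | "threadIdx" | None, "iter_type": "spatial" | ...}
    i_block_idx = i_thread_idx = i_spatial_loop = -1
    i_multi_child = len(loops)  # stand-in for the elided multi-child scan
    for i, loop in enumerate(loops):
        if loop["bind"] == "blockIdx":
            i_block_idx = i
        elif loop["bind"] == "threadIdx":
            i_thread_idx = i
        # Extend the spatial prefix only while it is unbroken.
        if loop["iter_type"] == "spatial" and i_spatial_loop == i - 1:
            i_spatial_loop += 1
    if (i_block_idx != -1 and i_thread_idx != -1) or i_spatial_loop == -1:
        return "no-bind", 0
    if i_block_idx == -1 and i_thread_idx != -1:
        return "bind-block", min(i_multi_child, i_thread_idx, i_spatial_loop + 1)
    assert i_block_idx == -1 and i_thread_idx == -1  # blockIdx-only is unsupported
    return "bind-block-thread", min(i_multi_child, i_spatial_loop + 1)
```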
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/auto_inline.cc
@@ -130,7 +130,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
// Last cond: Check inline into the spatial consumer or the spatial producer
if (into_consumer) {
Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
if (consumer_srefs.size() == 1 && IsSpatial(consumer_srefs[0])) {
if (!consumer_srefs.empty()) {
if (!into_cache_only ||
tir::GetAnn<Integer>(consumer_srefs[0], tir::attr::meta_schedule_cache_type).defined()) {
if (CanComputeInline(state, block_sref)) {
79 changes: 67 additions & 12 deletions src/meta_schedule/schedule_rule/cross_thread_reduction.cc
@@ -67,20 +67,26 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
GetComputeTargetLoopAndBlock(tmp_sch, block_rv);

// Step 3. Try block fusion.
Array<tir::ExprRV> factors{nullptr};
if (fusible) {
ICHECK(target_block.defined());
ICHECK(target_loop.defined());

// Step 3.1. If the outer loops of `target_block` haven't been bound to threadIdx, we should
// first bound the innermost outer loop of `target_block` to threadIdx. Possibly we need to
// split the loop before binding.
// Step 3.1.
// - If the outer loops of `target_block` haven't been bound to "threadIdx.x", we should first
//   bind the innermost outer loop of `target_block` to threadIdx. We may need to split the
//   loop before binding.
// - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor.
if (!InThreadScope(tmp_sch, target_block)) {
factors = tmp_sch->SamplePerfectTile(tgt_block_innermost_loop, 2, max_innermost_factor);
const Array<tir::LoopRV>& split_res =
tmp_sch->Split(tgt_block_innermost_loop, {NullOpt, Integer(warp_size)});
tmp_sch->Split(tgt_block_innermost_loop, {factors.begin(), factors.end()});
tmp_sch->Bind(split_res[1], "threadIdx.x");
if (tgt_block_innermost_loop.same_as(target_loop)) {
target_loop = split_res[0];
}
} else {
factors = {tir::ExprRV{nullptr}, GetThreadIdxExtentFromTrace(tmp_sch->trace().value())};
}
// Step 3.2. Do the compute-at.
tmp_sch->ComputeAt(block_rv, target_loop, /*preserve_unit_loops=*/true);
@@ -94,8 +100,10 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
tir::LoopRV fused_reduce_loop;
ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
// Step 5. Split the fused reduction loop and bind the inner one to threadIdx.
const Array<tir::LoopRV>& split_res =
tmp_sch->Split(fused_reduce_loop, {NullOpt, Integer(warp_size)});
if (!factors.defined()) {
factors = tmp_sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
}
const Array<tir::LoopRV>& split_res = tmp_sch->Split(fused_reduce_loop, {NullOpt, factors[1]});
tmp_sch->Bind(split_res[1], "threadIdx.x");

return {tmp_sch, sch};
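A minimal TIR-schedule sketch of the split-and-bind pattern that Steps 3.1 and 5 perform, shown on a toy row-sum (the workload, names, and factor cap are illustrative):

```python
import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def sum_rows(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
    for i, k in T.grid(128, 128):
        with T.block("C"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                B[vi] = T.float32(0)
            B[vi] = B[vi] + A[vi, vk]


sch = tir.Schedule(sum_rows)
i, k = sch.get_loops(sch.get_block("C"))
# Analogue of Step 5: sample a 2-way perfect tile of the reduction loop,
# capping the inner factor (the future "threadIdx.x" extent) at 64.
factors = sch.sample_perfect_tile(k, n=2, max_innermost_factor=64)
ko, ki = sch.split(k, factors=factors)
sch.bind(ki, "threadIdx.x")
```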
@@ -113,16 +121,53 @@
const Array<tir::LoopRV>& axes = sch->GetLoops(block);
for (const tir::LoopRV& loop_rv : axes) {
const tir::For& loop = sch->Get(loop_rv);
if (!loop->thread_binding.defined()) {
continue;
runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get());
if (tir::IsThreadIdx(thread_scope)) {
return true;
}
if (std::string(loop->thread_binding.value()->thread_tag).substr(0, 9) == "threadIdx") {
}
return false;
}

/*!
* \brief Get the ExprRV used to define the extent of a given loop.
* \param trace The trace of the schedule, where the extent is to be found
* \param loop The loop whose extent is to be found
* \param extent The output parameter that receives the extent
* \return Whether the extent was found.
*/
bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop,
tir::ExprRV* extent) {
for (const tir::Instruction& inst : trace->insts) {
if (inst->kind->name == "Split") {
int i = std::find(inst->outputs.begin(), inst->outputs.end(), loop) - inst->outputs.begin();
if (i == static_cast<int>(inst->outputs.size())) {
  continue;  // this Split did not produce `loop`; keep scanning the trace
}
CHECK(inst->inputs[1 + i].defined())
    << "ValueError: Extracting an extent which needs inference is not supported so far";
*extent = Downcast<tir::ExprRV>(inst->inputs[1 + i]);
return true;
}
}
return false;
}

/*!
* \brief Get the ExprRV extent of "threadIdx.x" in the given schedule trace.
* \param trace The trace of the schedule, where the extent is to be found
* \return The extent of "threadIdx.x" in the input schedule
*/
tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) {
tir::ExprRV extent{nullptr};
for (const tir::Instruction& inst : trace->insts) {
if (inst->kind->name == "Bind" && Downcast<String>(inst->attrs[0]) == "threadIdx.x") {
if (GetLoopRVExtentSource(trace, Downcast<tir::LoopRV>(inst->inputs[0]), &extent)) {
return extent;
}
}
}
CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\"";
throw;
}
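The two helpers above walk the trace from the `Bind` instruction back to the `Split` that produced the bound loop; a Python paraphrase, assuming the extent was passed to `Split` explicitly rather than left for inference:

```python
def thread_idx_extent(trace):
    """Recover the ExprRV used as the "threadIdx.x" extent, or None."""

    def split_extent_of(loop_rv):
        for inst in trace.insts:
            if inst.kind.name == "Split" and loop_rv in list(inst.outputs):
                idx = list(inst.outputs).index(loop_rv)
                return inst.inputs[1 + idx]  # inputs[0] is the loop being split
        return None

    for inst in trace.insts:
        if inst.kind.name == "Bind" and str(inst.attrs[0]) == "threadIdx.x":
            return split_extent_of(inst.inputs[0])
    return None
```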

/*!
* \brief Get the compute-at target loop and the first block under the target loop.
* \param sch The TensorIR schedule
@@ -146,11 +191,17 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
}

// Step 3. Calculate the lowest common ancestor of all the consumers.
// - If the lowest common ancestor is a block, either there is only one consumer, or the LCA is
// the scope block, and thereby the target block is the first consumer;
// - If the lowest common ancestor is a block:
// - if there is only one consumer, the target block is that consumer;
// - if there are multiple consumers, they must not share a common loop, and the case is not
// fusible;
// - If the lowest common ancestor is a loop, the target block is also the first consumer.
const tir::StmtSRef& lca_sref =
tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers));
if (consumers.size() > 1 && lca_sref->StmtAs<tir::BlockNode>() != nullptr) {
return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
tir::LoopRV{nullptr});
}

// Step 4. Get the outer loops of the target block, and get the compute-at position index.
Array<tir::LoopRV> tgt_block_loops = sch->GetLoops(consumers[0]);
@@ -198,18 +249,22 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
int max_threads_per_block;
/*! \brief The number of threads per warp */
int warp_size;
/*! \brief The maximum size of the innermost factor. "-1" means no limit */
int max_innermost_factor;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("max_threads_per_block", &max_threads_per_block);
v->Visit("warp_size", &warp_size);
v->Visit("max_innermost_factor", &max_innermost_factor);
}

static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction";
TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode);
};

ScheduleRule ScheduleRule::CrossThreadReduction() {
ScheduleRule ScheduleRule::CrossThreadReduction(Optional<Integer> max_innermost_factor) {
ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>();
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
return ScheduleRule(n);
}

6 changes: 4 additions & 2 deletions src/tir/schedule/trace.cc
@@ -435,8 +435,10 @@ Trace TraceNode::Simplified(bool remove_postproc) const {
}
// Add its inputs as "used" ones
for (const ObjectRef& obj : inst->inputs) {
if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
obj->IsInstance<VarNode>()) {
if (!obj.defined()) {
continue;
} else if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
obj->IsInstance<VarNode>()) {
used_rvs.insert(obj.get());
continue;
} else if (obj->IsInstance<PrimExprNode>()) {
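The undefined input arises from traces where one `Split` factor is left as `NullOpt` for inference, which is exactly what the cross-thread reduction rule now produces; a sketch of a schedule whose trace contains such an input (toy workload, illustrative):

```python
import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def add_one(A: T.Buffer[(128,), "float32"], B: T.Buffer[(128,), "float32"]) -> None:
    for i in T.serial(128):
        with T.block("B"):
            vi = T.axis.S(128, i)
            B[vi] = A[vi] + T.float32(1)


sch = tir.Schedule(add_one)
(i,) = sch.get_loops(sch.get_block("B"))
# The None factor is recorded as an undefined input in the trace;
# Simplified() must skip it instead of dereferencing a null ObjectRef.
io, ii = sch.split(i, factors=[None, 32])
sch.bind(ii, "threadIdx.x")
```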
1 change: 1 addition & 0 deletions tests/python/meta_schedule/run_ansor_cuda.sh
@@ -33,6 +33,7 @@ run DIL
run GMM
run GRP
run NRM
run SFM
run T2D
# Subgraph
run C2d-BN-RELU
3 changes: 2 additions & 1 deletion tests/python/meta_schedule/run_meta_schedule_cuda.sh
@@ -34,7 +34,8 @@ run DEP
run DIL
run GMM
run GRP
# run NRM
run NRM
run SFM
run T2D
# Subgraph
run C2d-BN-RELU
@@ -46,7 +46,7 @@ def _create_context(mod, target) -> TuneContext:


@tvm.script.ir_module
class Before:
class Before_cooperative_fetch:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle) -> None:
A = T.match_buffer(var_A, [512, 512], dtype="float32")
@@ -58,7 +58,7 @@ def main(var_A: T.handle, var_B: T.handle) -> None:


@tvm.script.ir_module
class After:
class After_cooperative_fetch:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle) -> None:
A = T.match_buffer(var_A, [512, 512], dtype="float32")
@@ -71,19 +71,70 @@ def main(var_A: T.handle, var_B: T.handle) -> None:
B[vi, vj] = A[vi, vj] + 1.0


@tvm.script.ir_module
class Before_norm_bmn:
@T.prim_func
def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None:
C = T.alloc_buffer([1], dtype="float32")
for i0, i1, i2 in T.grid(1, 256, 256):
with T.block("C"):
b, i, j = T.axis.remap("SRR", [i0, i1, i2])
with T.init():
C[b] = T.float32(0)
C[b] = C[b] + A[b, i, j] * A[b, i, j]
for i0 in T.serial(1):
with T.block("D"):
b = T.axis.S(1, i0)
D[b] = T.sqrt(C[b], dtype="float32")


@tvm.script.ir_module
class After_norm_bmn:
@T.prim_func
def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None:
C = T.alloc_buffer([1], dtype="float32")
for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
for i1, i2 in T.grid(256, 256):
with T.block("C"):
b = T.axis.S(1, 0)
i, j = T.axis.remap("RR", [i1, i2])
T.where(i0_fused_0 * 32 + i0_fused_1 < 1)
with T.init():
C[b] = T.float32(0)
C[b] = C[b] + A[b, i, j] * A[b, i, j]
for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
with T.block("D"):
b = T.axis.S(1, 0)
T.where(i0_fused_0 * 32 + i0_fused_1 < 1)
D[b] = T.sqrt(C[b], dtype="float32")


# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on


def test_rewrite_cooperative_fetch():
mod = Before
mod = Before_cooperative_fetch
target = _target()
ctx = _create_context(mod, target)
sch = tir.Schedule(mod, debug_mask="all")
sch.enter_postproc()
assert ctx.postprocs[0].apply(sch)
tvm.ir.assert_structural_equal(sch.mod, After_cooperative_fetch)


def test_rewrite_norm_bmn():
mod = Before_norm_bmn
target = _target()
ctx = _create_context(mod, target)
sch = tir.Schedule(mod, debug_mask="all")
sch.enter_postproc()
assert ctx.postprocs[0].apply(sch)
tvm.ir.assert_structural_equal(sch.mod, After)
tvm.ir.assert_structural_equal(sch.mod, After_norm_bmn)


if __name__ == "__main__":
test_rewrite_cooperative_fetch()
test_rewrite_norm_bmn()