Skip to content

Commit

Permalink
[MetaSchedule] Performance Alignment - NRM and SFM (CUDA) (#559)
Browse files Browse the repository at this point in the history
* Add rule cross-thread reduction to tune.py

* Skip undefined objects during simplification

* Fix postproc RewriteUnboundBlock

* Use deep copy in PerStoreFeature

* Update RewriteUnboundBlock

* Use sampling in rule CrossThreadReduction

* Add a not-fusible case

* Support follow-split in rule cross-thread reduction

* Add unittest for trace simplification

* Fix AutoInline

* Add workload SFM
  • Loading branch information
MasterJH5574 authored Dec 27, 2021
1 parent dab6c9c commit dea5038
Show file tree
Hide file tree
Showing 17 changed files with 388 additions and 59 deletions.
3 changes: 2 additions & 1 deletion include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,10 @@ class ScheduleRule : public runtime::ObjectRef {
/*!
* \brief Create a schedule rule which applies cross-thread reduction to some reduction blocks
* correspondingly when needed
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule CrossThreadReduction();
TVM_DLL static ScheduleRule CrossThreadReduction(Optional<Integer> max_innermost_factor);
/*!
* \brief A rule that randomly selects a compute-at location for a free block
* \return The rule created
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/schedule_rule/add_rfactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AddRFactor(ScheduleRule):
parallelism, i.e. `num_cores * max_jobs_per_core`.
Use -1 to disable parallelism.
max_innermost_factor: Optional[int] = None
The maximum size of the innermost factor. NullOpt means no limit.
The maximum size of the innermost factor. None means no limit.
"""

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,15 @@
class CrossThreadReduction(ScheduleRule):
"""A schedule rule which applies cross-thread reduction to some reduction blocks
correspondingly when needed
Parameters
----------
max_innermost_factor: Optional[int] = None
The maximum size of the innermost factor. None means no limit.
"""

def __init__(self) -> None:
def __init__(self, max_innermost_factor: Optional[int] = None) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleCrossThreadReduction, # type: ignore # pylint: disable=no-member
max_innermost_factor,
)
3 changes: 2 additions & 1 deletion python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def get(target: Target) -> List[ScheduleRule]:
if target.kind.name == "cuda":
return [
auto_inline(target),
cross_thread_reduction(target),
multi_level_tiling(target),
auto_inline_after_tiling(target),
parallel_vectorize_unroll(target),
Expand Down Expand Up @@ -197,5 +198,5 @@ def add_rfactor(target: Target) -> ScheduleRule:
def cross_thread_reduction(target: Target) -> ScheduleRule:
"""Default schedule rules for with cross-thread reduction"""
if target.kind.name == "cuda":
return CrossThreadReduction()
return CrossThreadReduction(max_innermost_factor=64)
raise NotImplementedError(f"{target.kind.name} is not supported")
9 changes: 9 additions & 0 deletions python/tvm/meta_schedule/testing/te_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,15 @@ def create_te_workload(name: str, idx: int) -> tir.PrimFunc:
(1, 4096, 1024),
],
),
"SFM": (
softmax_mn,
[
(256, 256),
(512, 512),
(1024, 1024),
(2048, 2048),
],
),
"C2d-BN-RELU": (
conv2d_nhwc_bn_relu,
[
Expand Down
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def _sch_rules() -> List[ScheduleRule]:
require_ordered=False,
disallow_op=None,
),
M.CrossThreadReduction(max_innermost_factor=64),
M.MultiLevelTiling(
structure="SSSRRSRS",
tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/feature_extractor/per_store_feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void {
const auto& candidate = candidates[task_id];
std::vector<std::vector<double>> features;
ExtractSingle(candidate->sch->mod(), is_gpu, &features);
ExtractSingle(DeepCopyIRModule(candidate->sch->mod()), is_gpu, &features);
results[task_id] = tir::utils::AsNDArray(features);
};
support::parallel_for_dynamic(0, candidates.size(), tune_context->num_threads, f);
Expand Down
12 changes: 9 additions & 3 deletions src/meta_schedule/postproc/rewrite_unbound_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) {
int i_block_idx = -1;
int i_thread_idx = -1;
int i_multi_child = -1;
int i_spatial_loop = -1;
for (int i = 0; i < n; ++i) {
const StmtSRef& loop_sref = loops[i];
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
Expand All @@ -65,20 +66,25 @@ BindType GetBindType(const StmtSRef& block_sref, int* fuse_first_num) {
i_multi_child = i + 1;
}
}
if (tir::GetLoopIterType(loop_sref) == IterVarType::kDataPar) {
if (i_spatial_loop == i - 1) {
++i_spatial_loop;
}
}
}
if (i_multi_child == -1) {
i_multi_child = n;
}
if (i_block_idx != -1 && i_thread_idx != -1) {
if ((i_block_idx != -1 && i_thread_idx != -1) || i_spatial_loop == -1) {
return BindType::kNoBind;
} else if (i_block_idx != -1 && i_thread_idx == -1) {
ICHECK(false) << "Unsupported case, where blockIdx is bound but threadIdx is not";
throw;
} else if (i_block_idx == -1 && i_thread_idx != -1) {
*fuse_first_num = std::min(i_multi_child, i_thread_idx);
*fuse_first_num = std::min(std::min(i_multi_child, i_thread_idx), i_spatial_loop + 1);
return BindType::kBindBlock;
} else { // i_block_idx == -1 && i_thread_idx == -1
*fuse_first_num = i_multi_child;
*fuse_first_num = std::min(i_multi_child, i_spatial_loop + 1);
return BindType::kBindBlockThread;
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/schedule_rule/auto_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
// Last cond: Check inline into the spatial consumer or the spatial producer
if (into_consumer) {
Array<tir::StmtSRef> consumer_srefs = GetConsumers(state, block_sref);
if (consumer_srefs.size() == 1 && IsSpatial(consumer_srefs[0])) {
if (!consumer_srefs.empty()) {
if (!into_cache_only ||
tir::GetAnn<Integer>(consumer_srefs[0], tir::attr::meta_schedule_cache_type).defined()) {
if (CanComputeInline(state, block_sref)) {
Expand Down
79 changes: 67 additions & 12 deletions src/meta_schedule/schedule_rule/cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,26 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
GetComputeTargetLoopAndBlock(tmp_sch, block_rv);

// Step 3. Try block fusion.
Array<tir::ExprRV> factors{nullptr};
if (fusible) {
ICHECK(target_block.defined());
ICHECK(target_loop.defined());

// Step 3.1. If the outer loops of `target_block` haven't been bound to threadIdx, we should
// first bound the innermost outer loop of `target_block` to threadIdx. Possibly we need to
// split the loop before binding.
// Step 3.1.
// - If the outer loops of `target_block` haven't been bound to "threadIdx.x", we should first
// bind the innermost outer loop of `target_block` to threadIdx. Possibly we need to split
// the loop before binding.
// - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor.
if (!InThreadScope(tmp_sch, target_block)) {
factors = tmp_sch->SamplePerfectTile(tgt_block_innermost_loop, 2, max_innermost_factor);
const Array<tir::LoopRV>& split_res =
tmp_sch->Split(tgt_block_innermost_loop, {NullOpt, Integer(warp_size)});
tmp_sch->Split(tgt_block_innermost_loop, {factors.begin(), factors.end()});
tmp_sch->Bind(split_res[1], "threadIdx.x");
if (tgt_block_innermost_loop.same_as(target_loop)) {
target_loop = split_res[0];
}
} else {
factors = {tir::ExprRV{nullptr}, GetThreadIdxExtentFromTrace(tmp_sch->trace().value())};
}
// Step 3.2. Do the compute-at.
tmp_sch->ComputeAt(block_rv, target_loop, /*preserve_unit_loops=*/true);
Expand All @@ -94,8 +100,10 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
tir::LoopRV fused_reduce_loop;
ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
// Step 5. Split the fused reduction loop and bind the inner one to threadIdx.
const Array<tir::LoopRV>& split_res =
tmp_sch->Split(fused_reduce_loop, {NullOpt, Integer(warp_size)});
if (!factors.defined()) {
factors = tmp_sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
}
const Array<tir::LoopRV>& split_res = tmp_sch->Split(fused_reduce_loop, {NullOpt, factors[1]});
tmp_sch->Bind(split_res[1], "threadIdx.x");

return {tmp_sch, sch};
Expand All @@ -113,16 +121,53 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
const Array<tir::LoopRV>& axes = sch->GetLoops(block);
for (const tir::LoopRV& loop_rv : axes) {
const tir::For& loop = sch->Get(loop_rv);
if (!loop->thread_binding.defined()) {
continue;
runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get());
if (tir::IsThreadIdx(thread_scope)) {
return true;
}
if (std::string(loop->thread_binding.value()->thread_tag).substr(0, 9) == "threadIdx") {
}
return false;
}

/*!
 * \brief Get the ExprRV which is used to define the extent of a given loop.
 * \param trace The trace of the schedule, where the extent is to be found
 * \param loop The loop whose extent is to be found
 * \param extent The finding result. Only written when the function returns true.
 * \return Whether the find is successful, i.e. whether some "Split" instruction in the trace
 * lists `loop` among its outputs.
 * \note For a "Split" instruction, `inputs[0]` is the loop being split and `inputs[1 + i]` is
 * the factor that defines the extent of `outputs[i]`.
 */
bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop,
                           tir::ExprRV* extent) {
  for (const tir::Instruction& inst : trace->insts) {
    if (inst->kind->name != "Split") {
      continue;
    }
    // Only a "Split" that actually produced `loop` can tell us its extent. Without this check the
    // first "Split" in the trace would always be used, and a miss would index `inputs` one past
    // its end.
    const auto it = std::find(inst->outputs.begin(), inst->outputs.end(), loop);
    if (it == inst->outputs.end()) {
      continue;
    }
    int i = it - inst->outputs.begin();
    // An undefined factor means the extent was left to be inferred by the split.
    CHECK(inst->inputs[1 + i].defined())
        << "ValueError: Extracting an extent which needs inference is not supported so far";
    *extent = Downcast<tir::ExprRV>(inst->inputs[1 + i]);
    return true;
  }
  return false;
}

/*!
 * \brief Look up the ExprRV that defines the extent of "threadIdx.x" in a schedule trace.
 * \param trace The schedule trace to scan
 * \return The extent ExprRV of the first "threadIdx.x"-bound loop whose extent source can be
 * located in the trace
 * \note Fails with a ValueError when no "Bind" instruction to "threadIdx.x" yields an extent.
 */
tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) {
  tir::ExprRV thread_extent{nullptr};
  for (const tir::Instruction& inst : trace->insts) {
    // Only "Bind" instructions whose first attribute is the "threadIdx.x" tag are of interest.
    const bool binds_thread_idx_x =
        inst->kind->name == "Bind" && Downcast<String>(inst->attrs[0]) == "threadIdx.x";
    if (!binds_thread_idx_x) {
      continue;
    }
    // The first input of the "Bind" instruction is the loop being bound.
    if (GetLoopRVExtentSource(trace, Downcast<tir::LoopRV>(inst->inputs[0]), &thread_extent)) {
      return thread_extent;
    }
  }
  CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\"";
  throw;
}

/*!
* \brief Get the compute-at target loop and the first block under the target loop.
* \param sch The TensorIR schedule
Expand All @@ -146,11 +191,17 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
}

// Step 3. Calculate the lowest common ancestor of all the consumers.
// - If the lowest common ancestor is a block, either there is only one consumer, or the LCA is
// the scope block, and thereby the target block is the first consumer;
// - If the lowest common ancestor is a block:
// - if there is only one consumer, the target block is that consumer;
// - if there are multiple consumers, they must not share a common loop, and the case is not
// fusible;
// - If the lowest common ancestor is a loop, the target block is also the first consumer.
const tir::StmtSRef& lca_sref =
tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers));
if (consumers.size() > 1 && lca_sref->StmtAs<tir::BlockNode>() != nullptr) {
return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
tir::LoopRV{nullptr});
}

// Step 4. Get the outer loops of the target block, and get the compute-at position index.
Array<tir::LoopRV> tgt_block_loops = sch->GetLoops(consumers[0]);
Expand Down Expand Up @@ -198,18 +249,22 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
int max_threads_per_block;
/*! \brief The number of threads per warp */
int warp_size;
/*! \brief The maximum size of the innermost factor. "-1" means no limit */
int max_innermost_factor;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("max_threads_per_block", &max_threads_per_block);
v->Visit("warp_size", &warp_size);
v->Visit("max_innermost_factor", &max_innermost_factor);
}

static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction";
TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode);
};

ScheduleRule ScheduleRule::CrossThreadReduction() {
ScheduleRule ScheduleRule::CrossThreadReduction(Optional<Integer> max_innermost_factor) {
ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>();
n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
return ScheduleRule(n);
}

Expand Down
6 changes: 4 additions & 2 deletions src/tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,10 @@ Trace TraceNode::Simplified(bool remove_postproc) const {
}
// Add its inputs as "used" ones
for (const ObjectRef& obj : inst->inputs) {
if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
obj->IsInstance<VarNode>()) {
if (!obj.defined()) {
continue;
} else if (obj->IsInstance<BlockRVNode>() || obj->IsInstance<LoopRVNode>() ||
obj->IsInstance<VarNode>()) {
used_rvs.insert(obj.get());
continue;
} else if (obj->IsInstance<PrimExprNode>()) {
Expand Down
1 change: 1 addition & 0 deletions tests/python/meta_schedule/run_ansor_cuda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ run DIL
run GMM
run GRP
run NRM
run SFM
run T2D
# Subgraph
run C2d-BN-RELU
Expand Down
3 changes: 2 additions & 1 deletion tests/python/meta_schedule/run_meta_schedule_cuda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ run DEP
run DIL
run GMM
run GRP
# run NRM
run NRM
run SFM
run T2D
# Subgraph
run C2d-BN-RELU
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _create_context(mod, target) -> TuneContext:


@tvm.script.ir_module
class Before:
class Before_cooperative_fetch:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle) -> None:
A = T.match_buffer(var_A, [512, 512], dtype="float32")
Expand All @@ -58,7 +58,7 @@ def main(var_A: T.handle, var_B: T.handle) -> None:


@tvm.script.ir_module
class After:
class After_cooperative_fetch:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle) -> None:
A = T.match_buffer(var_A, [512, 512], dtype="float32")
Expand All @@ -71,19 +71,70 @@ def main(var_A: T.handle, var_B: T.handle) -> None:
B[vi, vj] = A[vi, vj] + 1.0


# Input fixture for test_rewrite_norm_bmn: a reduction block followed by an
# elementwise block, with no loop bound to a thread axis.
@tvm.script.ir_module
class Before_norm_bmn:
@T.prim_func
def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None:
C = T.alloc_buffer([1], dtype="float32")
for i0, i1, i2 in T.grid(1, 256, 256):
# Block "C": b is spatial, i and j are reduction axes ("SRR"); accumulates
# A[b, i, j] * A[b, i, j] into C[b], starting from 0.
with T.block("C"):
b, i, j = T.axis.remap("SRR", [i0, i1, i2])
with T.init():
C[b] = T.float32(0)
C[b] = C[b] + A[b, i, j] * A[b, i, j]
for i0 in T.serial(1):
# Block "D": writes the square root of C[b] to D[b].
with T.block("D"):
b = T.axis.S(1, i0)
D[b] = T.sqrt(C[b], dtype="float32")


# Expected module for test_rewrite_norm_bmn: the same "C" and "D" blocks, each
# nested under a blockIdx.x loop of extent 1 and a threadIdx.x loop of extent 32.
@tvm.script.ir_module
class After_norm_bmn:
@T.prim_func
def main(A: T.Buffer[(1, 256, 256), "float32"], D: T.Buffer[(1,), "float32"]) -> None:
C = T.alloc_buffer([1], dtype="float32")
for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
# The reduction loops i1 and i2 stay as plain grid loops under the thread bindings.
for i1, i2 in T.grid(256, 256):
with T.block("C"):
b = T.axis.S(1, 0)
i, j = T.axis.remap("RR", [i1, i2])
# Predicate restricting execution to fused indices below the extent of 1.
T.where(i0_fused_0 * 32 + i0_fused_1 < 1)
with T.init():
C[b] = T.float32(0)
C[b] = C[b] + A[b, i, j] * A[b, i, j]
for i0_fused_0 in T.thread_binding(1, thread="blockIdx.x"):
for i0_fused_1 in T.thread_binding(32, thread="threadIdx.x"):
with T.block("D"):
b = T.axis.S(1, 0)
# Same predicate as on block "C".
T.where(i0_fused_0 * 32 + i0_fused_1 < 1)
D[b] = T.sqrt(C[b], dtype="float32")


# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on


def test_rewrite_cooperative_fetch():
mod = Before
mod = Before_cooperative_fetch
target = _target()
ctx = _create_context(mod, target)
sch = tir.Schedule(mod, debug_mask="all")
sch.enter_postproc()
assert ctx.postprocs[0].apply(sch)
tvm.ir.assert_structural_equal(sch.mod, After_cooperative_fetch)


def test_rewrite_norm_bmn():
mod = Before_norm_bmn
target = _target()
ctx = _create_context(mod, target)
sch = tir.Schedule(mod, debug_mask="all")
sch.enter_postproc()
assert ctx.postprocs[0].apply(sch)
tvm.ir.assert_structural_equal(sch.mod, After)
tvm.ir.assert_structural_equal(sch.mod, After_norm_bmn)


if __name__ == "__main__":
test_rewrite_cooperative_fetch()
test_rewrite_norm_bmn()
Loading

0 comments on commit dea5038

Please sign in to comment.