finish rfactor
Hzfengsy committed Dec 13, 2021
1 parent d7e251e commit 54c3ba0
Showing 9 changed files with 151 additions and 382 deletions.
1 change: 0 additions & 1 deletion include/tvm/meta_schedule/schedule_rule.h
@@ -152,7 +152,6 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Integer> vector_load_max_len, //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Create a rule: add-rfactor to some blocks if needed
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
1 change: 1 addition & 0 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -21,3 +21,4 @@
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
from .schedule_rule import PyScheduleRule, ScheduleRule
from .add_rfactor import AddRFactor
46 changes: 12 additions & 34 deletions python/tvm/meta_schedule/schedule_rule/add_rfactor.py
@@ -25,47 +25,25 @@

@register_object("meta_schedule.AddRFactor")
class AddRFactor(ScheduleRule):
"""Rule that inlines spatial blocks if it satisfies some conditions
"""Rules for add-rfactor to some blocks if needed.
Parameters
----------
into_producer : bool
If allows to inline a block into its producer
into_consumer : bool
If allows to inline a block into its consumer
into_cache_only : bool
If it only allows to inline into a block generated by cache_read/write
inline_const_tensor : bool
Always inline constant tensors
disallow_if_then_else : bool
Always disallow if-then-else-like constructs
require_injective : bool
Always require the read-to-write mapping to be injective
require_ordered : bool
Always require the read-to-write mapping to be ordered
disallow_op : Optional[List[str]]
The operators that are disallowed in auto inline
max_jobs_per_core: int
The maximum number of jobs to be launched per CPU core. It sets the upper limit of CPU
parallelism, i.e. `num_cores * max_jobs_per_core`.
Use -1 to disable parallelism.
max_innermost_factor: Optional[int] = None
The maximum size of the innermost factor. NullOpt means no limit.
"""

def __init__(
self,
into_producer: bool,
into_consumer: bool,
into_cache_only: bool,
inline_const_tensor: bool,
disallow_if_then_else: bool,
require_injective: bool,
require_ordered: bool,
disallow_op: Optional[List[str]] = None,
max_jobs_per_core: int = 16,
max_innermost_factor: Optional[int] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleAutoInline, # type: ignore # pylint: disable=no-member
into_producer,
into_consumer,
into_cache_only,
inline_const_tensor,
disallow_if_then_else,
require_injective,
require_ordered,
disallow_op,
_ffi_api.ScheduleRuleAddRFactor, # type: ignore # pylint: disable=no-member
max_jobs_per_core,
max_innermost_factor,
)
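
As a quick orientation, a minimal usage sketch of the new rule (our own illustration, not part of this commit; the constructor values mirror the defaults used in testing/schedule_rule.py below, and the import path follows the __init__.py change above):

from tvm.meta_schedule.schedule_rule import AddRFactor

# max_jobs_per_core=16 caps CPU parallelism at num_cores * 16 (-1 disables it);
# max_innermost_factor=64 bounds the innermost split factor (None means no limit).
rule = AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
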
9 changes: 9 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -18,6 +18,7 @@
from typing import List

from tvm.meta_schedule.schedule_rule import (
AddRFactor,
AutoInline,
MultiLevelTiling,
ParallelizeVectorizeUnroll,
@@ -33,6 +34,7 @@ def get(target: Target) -> List[ScheduleRule]:
if target.kind.name == "llvm":
return [
auto_inline(target),
add_rfactor(target),
multi_level_tiling(target),
parallel_vectorize_unroll(target),
]
@@ -159,3 +161,10 @@ def random_compute_location(target: Target) -> ScheduleRule:
if target.kind.name == "llvm":
return RandomComputeLocation()
raise NotImplementedError(f"{target.kind.name} is not supported")


def add_rfactor(target: Target) -> ScheduleRule:
"""Default schedule rules for with add_rfactor"""
if target.kind.name == "llvm":
return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
raise NotImplementedError(f"{target.kind.name} is not supported")
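
A hedged usage sketch of these testing helpers (assuming `get` and `add_rfactor` are module-level functions in python/tvm/meta_schedule/testing/schedule_rule.py, as the diff suggests): the default LLVM rule list now includes AddRFactor between AutoInline and MultiLevelTiling, and non-LLVM targets raise NotImplementedError.

from tvm.target import Target
from tvm.meta_schedule.testing.schedule_rule import add_rfactor, get

target = Target("llvm --num-cores=16")
rules = get(target)                 # [AutoInline, AddRFactor, MultiLevelTiling, ...]
rfactor_rule = add_rfactor(target)  # AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
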
148 changes: 42 additions & 106 deletions src/meta_schedule/schedule_rule/add_rfactor.cc
@@ -21,59 +21,6 @@
namespace tvm {
namespace meta_schedule {

/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/

/*!
* \brief Reorder the reduction loops to innermost positions if needed.
* \param sch The schedule
* \param block_rv The block where to apply the reorder
* \param fused_reduce_loop The fusion-generated loop to return.
* \param num_spatial_loops The number of spatial loops to return.
* \note Before invoking this helper function, make sure that the block has only spatial and
* reduction loop axes.
*/
void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv,
tir::LoopRV* fused_reduce_loop, size_t* num_spatial_loops) {
Array<tir::LoopRV> loops = sch->GetLoops(block_rv);
Array<tir::StmtSRef> loop_srefs;
for (const tir::LoopRV& loop_rv : loops) {
loop_srefs.push_back(sch->GetSRef(loop_rv));
}

Array<tir::LoopRV> new_order;
// Step 1. Add spatial loops.
*num_spatial_loops = 0;
for (size_t i = 0; i < loops.size(); ++i) {
if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) {
new_order.push_back(loops[i]);
(*num_spatial_loops)++;
}
}
// Step 2. Add reduction loops.
Array<tir::LoopRV> reduction_loops;
for (size_t i = 0; i < loops.size(); ++i) {
if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) {
new_order.push_back(loops[i]);
reduction_loops.push_back(loops[i]);
}
}
// Step 3. Apply reordering if new_order differs from the original order.
ICHECK_EQ(new_order.size(), loops.size());
for (size_t i = 0; i < loops.size(); ++i) {
if (!new_order[i].same_as(loops[i])) {
sch->Reorder(new_order);
break;
}
}
// Step 4. Fuse all the reduction loops if there are multiple reduction loops.
CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop";
if (reduction_loops.size() > 1) {
*fused_reduce_loop = sch->Fuse(reduction_loops);
} else {
*fused_reduce_loop = reduction_loops[0];
}
}

class AddRFactorNode : public ScheduleRuleNode {
public:
// Inherited from ScheduleRuleNode
@@ -88,58 +35,7 @@ class AddRFactorNode : public ScheduleRuleNode {
}

// Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
if (!NeedAddRFactor(sch, block_rv)) {
return {sch};
}
// Make a copy of the original schedule.
tir::Schedule ori_sch = sch->Copy();
ori_sch->Seed(sch->ForkSeed());

// Reorder the loop axes if reduction loops are not innermost.
// After the reordering, fuse all the reduction loops.
size_t num_spatial_loops;
tir::LoopRV fused_reduce_loop;
ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops);

// Split the fused reduction loop.
Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
const Array<tir::LoopRV>& split_loops =
sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});

Array<tir::Schedule> res;
for (const tir::LoopRV& split_loop : split_loops) {
tir::Schedule sch_tmp = sch->Copy();
sch_tmp->Seed(sch->ForkSeed());
const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops);
Array<tir::LoopRV> axes = sch_tmp->GetLoops(block_rf);
ICHECK_GT(axes.size(), num_spatial_loops);
res.push_back(sch_tmp);
}

res.push_back(ori_sch);
return res;
}

bool NeedAddRFactor(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
tir::ScheduleState state = sch->state();
const tir::BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
// Cond 1. The block has no annotations
if (!block->annotations.empty()) {
return false;
}
// Cond 2. The block has only one write buffer
if (block->writes.size() != 1) {
return false;
}
if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_,
max_parallel_basic_) ||
HasCacheWriteBlock(sch, block_rv, 0)) {
return false;
}
return true;
}
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv);

public:
/*!
Expand All @@ -164,7 +60,7 @@ class AddRFactorNode : public ScheduleRuleNode {

static constexpr const char* _type_key = "meta_schedule.AddRFactor";
TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode);
}; // namespace meta_schedule
};

ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
Optional<Integer> max_innermost_factor) {
@@ -176,5 +72,45 @@ ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core,
return ScheduleRule(n);
}

Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
tir::StmtSRef block_sref = sch->GetSRef(block_rv);
if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_,
max_parallel_basic_)) {
return {sch};
}

// Make a copy of the original schedule.
tir::Schedule ori_sch = sch->Copy();
ori_sch->Seed(sch->ForkSeed());

// Reorder the loop axes if reduction loops are not innermost.
// After the reordering, fuse all the reduction loops.
size_t num_spatial_loops;
tir::LoopRV fused_reduce_loop;
ReorderAndFuseReductionLoops(sch, block_rv, &fused_reduce_loop, &num_spatial_loops);

// Split the fused reduction loop.
Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
const Array<tir::LoopRV>& split_loops =
sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});

Array<tir::Schedule> res;
for (const tir::LoopRV& split_loop : split_loops) {
tir::Schedule sch_tmp = sch->Copy();
sch_tmp->Seed(sch->ForkSeed());
const tir::BlockRV& block_rf = sch_tmp->RFactor(split_loop, num_spatial_loops);
Array<tir::LoopRV> axes = sch_tmp->GetLoops(block_rf);
ICHECK_GT(axes.size(), num_spatial_loops);
res.push_back(sch_tmp);
}

res.push_back(ori_sch);
return res;
}

TVM_REGISTER_NODE_TYPE(AddRFactorNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor")
.set_body_typed(ScheduleRule::AddRFactor);

} // namespace meta_schedule
} // namespace tvm
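
To make the Apply routine above concrete, here is a sketch of the transformation it performs on a row-sum reduction, written against the public tir.Schedule Python API (our own illustration, not part of this commit; a fixed split factor stands in for SamplePerfectTile). Apply produces one candidate per split loop, plus the untouched original schedule.

import tvm
from tvm import te, tir

A = te.placeholder((128, 4096), name="A")
k = te.reduce_axis((0, 4096), name="k")
B = te.compute((128,), lambda i: te.sum(A[i, k], axis=k), name="B")

sch = tir.Schedule(te.create_prim_func([A, B]))
block = sch.get_block("B")
_, loop_k = sch.get_loops(block)  # the reduction loop is already innermost and single
# Split the reduction loop two ways (Apply samples this split with SamplePerfectTile).
ko, ki = sch.split(loop_k, factors=[None, 16])
# rfactor introduces a partial-reduction block writing a B_rf buffer, later summed into B.
rf_block = sch.rfactor(ki, factor_axis=0)
print(sch.mod.script())
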
54 changes: 54 additions & 0 deletions src/meta_schedule/utils.h
@@ -303,6 +303,60 @@ inline Optional<tir::Schedule> ApplyTrace(const IRModule& mod, const tir::Trace&
return sch;
}

/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/

/*!
* \brief Reorder the reduction loops to innermost positions if needed.
* \param sch The schedule
* \param block_rv The block where to apply the reorder
* \param fused_reduce_loop The fusion-generated loop to return.
* \param num_spatial_loops The number of spatial loops to return.
* \note Before invoking this helper function, make sure that the block has only spatial and
* reduction loop axes.
*/
inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv,
tir::LoopRV* fused_reduce_loop,
size_t* num_spatial_loops) {
Array<tir::LoopRV> loops = sch->GetLoops(block_rv);
Array<tir::StmtSRef> loop_srefs;
for (const tir::LoopRV& loop_rv : loops) {
loop_srefs.push_back(sch->GetSRef(loop_rv));
}

Array<tir::LoopRV> new_order;
// Step 1. Add spatial loops.
*num_spatial_loops = 0;
for (size_t i = 0; i < loops.size(); ++i) {
if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) {
new_order.push_back(loops[i]);
(*num_spatial_loops)++;
}
}
// Step 2. Add reduction loops.
Array<tir::LoopRV> reduction_loops;
for (size_t i = 0; i < loops.size(); ++i) {
if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) {
new_order.push_back(loops[i]);
reduction_loops.push_back(loops[i]);
}
}
// Step 3. Apply reordering if new_order differs from the original order.
ICHECK_EQ(new_order.size(), loops.size());
for (size_t i = 0; i < loops.size(); ++i) {
if (!new_order[i].same_as(loops[i])) {
sch->Reorder(new_order);
break;
}
}
// Step 4. Fuse all the reduction loops if there are multiple reduction loops.
CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop";
if (reduction_loops.size() > 1) {
*fused_reduce_loop = sch->Fuse(reduction_loops);
} else {
*fused_reduce_loop = reduction_loops[0];
}
}

} // namespace meta_schedule
} // namespace tvm
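
The helper above gives add-rfactor (and later cross-thread reduction) a canonical loop structure before splitting. A sketch of the same two steps through the public Python API (our own illustration, assuming a recent TVMScript syntax; not part of this commit): the reduction loop k deliberately starts outside the spatial loop j, so it is reordered to the innermost positions and then fused with l.

import tvm
from tvm import tir
from tvm.script import tir as T

@T.prim_func
def row_sums(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16, 32, 64, 8), "float32")
    B = T.match_buffer(b, (16, 32), "float32")
    # Reduction loop k placed outside spatial loop j on purpose.
    for i, k, j, l in T.grid(16, 64, 32, 8):
        with T.block("B"):
            vi = T.axis.spatial(16, i)
            vj = T.axis.spatial(32, j)
            vk = T.axis.reduce(64, k)
            vl = T.axis.reduce(8, l)
            with T.init():
                B[vi, vj] = T.float32(0)
            B[vi, vj] = B[vi, vj] + A[vi, vj, vk, vl]

sch = tir.Schedule(row_sums)
i, k, j, l = sch.get_loops(sch.get_block("B"))
sch.reorder(i, j, k, l)  # Steps 1-3: spatial loops first, reduction loops innermost
fused = sch.fuse(k, l)   # Step 4: fuse the adjacent reduction loops into one
print(sch.mod.script())
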

10 changes: 0 additions & 10 deletions src/tir/schedule/analysis.h
@@ -550,16 +550,6 @@ bool HasOp(const Stmt& stmt, const Array<Op>& ops);
*/
bool HasIfThenElse(const Stmt& stmt);

/*!
* \brief Checks if the given block has cache write blocks.
* \param sch The traced schedule.
* \param block_rv The given block.
* \param write_buffer_index The index of the buffer in block's write region
* \return A boolean indicating whether the block has its cache write block in the trace.
*/
bool HasCacheWriteBlock(const tir::Schedule& sch, const BlockRV& block_rv,
const int& write_buffer_index);

/*!
* \brief Checks if a block could be successfully computed inline into its consumer
* \param self The schedule state