Skip to content

Commit

Permalink
[TIR] Add merge primitive for TIR schedule
Browse files Browse the repository at this point in the history
  • Loading branch information
yincs-intellif committed Mar 27, 2023
1 parent 0d0d2f0 commit 2c1bb3d
Show file tree
Hide file tree
Showing 10 changed files with 514 additions and 0 deletions.
10 changes: 10 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,16 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
/******** Schedule: Transform loops ********/
/*!
* \brief Merge a list of loops into one. The loops under their LCA requires:
* 1) Under the same scope.
* 2) Can't have annotations or thread bindings
* 3) Start with 0 and have same domain.
* 4) The inner loop must be the only child of the outer loop.
* \param loop_rvs The loops to the loops to be merged
* \return The new loop after merge
*/
virtual LoopRV Merge(const Array<LoopRV>& loop_rvs) = 0;
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
* 1) The loops can't have annotations or thread bindings.
Expand Down
78 changes: 78 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,84 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]:
return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member

########## Schedule: Transform loops ##########
@type_checked
def merge(
self,
*loops: List[LoopRV],
) -> LoopRV:
"""Merge a list of loops into one. The loops under their LCA requires:
1) Under the same scope
2) Can't have annotations or thread bindings.
3) Start with 0 and have same domain
4) The inner loop must be the only child of the outer loop.
Parameters
----------
*loops : List[LoopRV]
The loops to be merged
Returns
-------
fused_loop : LoopRV
The new loop after merge
Examples
--------
Before applying merge, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_merge(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
C = T.match_buffer(c, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do fuse:
.. code-block:: python
sch = tir.Schedule(before_fuse)
i1, _ = sch.get_loops(sch.get_block("B"))
i2, _ = sch.get_loops(sch.get_block("C"))
sch.merge(i1, i2)
print(sch.mod["main"].script())
After applying fuse, the IR becomes:
.. code-block:: python
@T.prim_func
def after_fuse(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
C = T.match_buffer(c, (128, 128))
# the 2 loops are merged into 1
for i_m in range(128):
for j in range(128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i_m, j])
T.reads(A[vi, vj])
T.writes(B[vi, vj])
B[vi, vj] = A[vi, vj] * T.float32(2)
for j in range(128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i_m, j])
T.reads(A[vi, vj])
T.writes(C[vi, vj])
C[vi, vj] = A[vi, vj] * T.float32(2)
"""
return _ffi_api.ScheduleMerge(self, loops) # type: ignore # pylint: disable=no-member

@type_checked
def fuse(
self,
Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,17 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {

/******** Schedule: Transform loops ********/

LoopRV ConcreteScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
CHECK(!loop_rvs.empty()) << "ValueError: 'merge' requires at least 1 loop(s)";
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::Merge(state_, loop_srefs);
TVM_TIR_SCHEDULE_END("merge", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(result);
}

LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) {
CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class ConcreteScheduleNode : public ScheduleNode {
Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
LoopRV Merge(const Array<LoopRV>& loop_rvs) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors,
bool preserve_unit_iters) override;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
Expand Down
13 changes: 13 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,19 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sr
*/
TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
const Array<PrimExpr>& factors, bool preserve_unit_iters);

/*!
* \brief Merge a list of loops into one. The loops under their LCA requires:
* 1) Under the same scope.
* 2) Can't have annotations or thread bindings
* 3) Start with 0 and have same domain.
* 4) The inner loop must be the only child of the outer loop.
* \param self The state of the schedule
* \param loop_srefs An array of srefs to the loops to be merged
* \return The new loop after merge
*/
TVM_DLL StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs);

/*!
* \brief Fuse a list of consecutive loops into one. It requires:
* 1) The loops can't have annotations or thread bindings.
Expand Down
192 changes: 192 additions & 0 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,165 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array
return result_srefs;
}

class LoopReconstructor : private StmtMutator {
public:
explicit LoopReconstructor(Block scope_root, std::vector<std::vector<const ForNode*>> loops)
: scope_root_(scope_root), loops_(loops) {}

using StmtMutator::operator();

/*!
* \brief Create the new nest loops induced by the given loops
*/
void MakeNewLoop() {
Array<Var> new_loop_vars;
Array<PrimExpr> new_loop_extents;
Array<Stmt> new_stmts;
for (size_t i = 0; i < loops_.size(); i++) {
Map<Var, PrimExpr> var_map;
for (size_t j = 0; j < loops_[i].size(); j++) {
if (i == 0) {
std::string suffix = loops_[i][j]->loop_var->name_hint;
suffix += "_m";
int bits = loops_[i][j]->loop_var.dtype().bits();
Var merged_var(suffix, DataType::Int(bits));
new_loop_vars.push_back(merged_var);
new_loop_extents.push_back(loops_[i][j]->extent);
}
var_map.Set(loops_[i][j]->loop_var, new_loop_vars[j]);
}
auto new_stmt = Substitute(loops_[i][0]->body, var_map);
new_stmts.push_back(new_stmt);
this->need_remove_loop_.push_back(loops_[i].back());
}
auto new_loop = For(new_loop_vars[0], Integer(0), new_loop_extents[0], ForKind::kSerial,
SeqStmt(std::move(new_stmts)));
this->new_inner_loop_ = new_loop;
for (size_t i = 1; i < new_loop_vars.size(); ++i) {
const Var& loop_var = new_loop_vars[i];
const PrimExpr& loop_extent = new_loop_extents[i];
new_loop = For(loop_var, Integer(0), loop_extent, ForKind::kSerial, new_loop);
}
this->new_outer_loop_ = new_loop;
}

private:
Stmt VisitStmt_(const BlockNode* block) final {
if (block != scope_root_.get()) {
return GetRef<Block>(block);
}
return StmtMutator::VisitStmt_(block);
}

Stmt VisitStmt_(const ForNode* loop) final {
if (loop == need_remove_loop_.back()) {
return new_outer_loop_;
} else if (std::count(need_remove_loop_.begin(), need_remove_loop_.end(), loop)) {
return Evaluate(0);
}
return StmtMutator::VisitStmt_(loop);
}

Stmt VisitStmt_(const SeqStmtNode* seq_stmt) final {
auto ret = Downcast<SeqStmt>(StmtMutator::VisitSeqStmt_(seq_stmt, true));
Array<Stmt> filtered;
for (Stmt stmt : ret->seq) {
if (!is_no_op(stmt)) {
filtered.push_back(std::move(stmt));
}
}
ret = SeqStmt(filtered);
if (ret->size() == 0) {
return Evaluate(0);
} else if (ret->size() == 1) {
return ret->seq[0];
} else {
return std::move(ret);
}
}

public:
/*! \brief The root block of the block scope */
Block scope_root_;
/*! \brief The given loops to be merge */
std::vector<std::vector<const ForNode*>> loops_;
/*! \brief The outermost new loop to replace the original loop */
For new_outer_loop_{nullptr};
/*! \brief The innermost new loop to replace the original loop */
For new_inner_loop_{nullptr};
/*! \brief The loops to be removed */
std::vector<const ForNode*> need_remove_loop_;
};

StmtSRef Merge(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
// Invariance
// - The total repeat number has not changed for each direct child block.
// - The execution order has not changed. (The block executes with the same
// args and the same order with before.)
arith::Analyzer analyzer;
StmtSRef scope_root_sref;
StmtSRef lca = GetSRefLowestCommonAncestor(loop_srefs);
std::vector<std::vector<const ForNode*>> lca_nest_loops;
// Step 1. check correctness
std::vector<const ForNode*> nest_loop_loops;
std::vector<PrimExpr> nest_loop_extents;
for (size_t i = 0; i < loop_srefs.size(); i++) {
const StmtSRef& sref = loop_srefs[i];
auto scope_root_sref_ = GetScopeRoot(self, sref, /*require_stage_pipeline=*/false);
std::vector<PrimExpr> nest_loop_i_extents;
std::vector<const ForNode*> nest_loop_i_loops;
for (auto p = sref.get(); p != lca.get(); p = p->parent) {
if (auto loop = p->StmtAs<ForNode>()) {
if (!loop->annotations.empty() || loop->thread_binding.defined()) {
throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
}
CheckLoopStartsWithZero(self, GetRef<StmtSRef>(p), &analyzer);
nest_loop_i_loops.push_back(loop);
nest_loop_i_extents.push_back(loop->extent);
}
}
lca_nest_loops.push_back(nest_loop_i_loops);
const ForNode* outer_loop = nullptr;
for (auto iter = nest_loop_i_loops.rbegin(); iter != nest_loop_i_loops.rend(); ++iter) {
if (outer_loop && !outer_loop->body.same_as(GetRef<For>(*iter))) {
throw NotOnlyChildError(self->mod, GetRef<For>(outer_loop), GetRef<For>(*iter));
}
outer_loop = *iter;
}
if (i == 0) {
scope_root_sref = scope_root_sref_;
nest_loop_loops = nest_loop_i_loops;
nest_loop_extents = nest_loop_i_extents;
} else {
if (scope_root_sref_.get() != scope_root_sref.get()) {
LOG(FATAL) << "ScheduleError: Expected the loops to be under the same block scope";
throw;
}
if (nest_loop_i_extents.size() != nest_loop_extents.size()) {
LOG(FATAL) << "ScheduleError: Merge loop's nesting depth must be same, but not";
throw;
} else {
for (size_t j = 0; j < nest_loop_i_extents.size(); j++) {
if (!analyzer.CanProveEqual(nest_loop_i_extents[j], nest_loop_extents[j])) {
LOG(FATAL) << "ScheduleError: Merge loop's `extent` must be same, but not."
<< "extent=[" << j << "," << nest_loop_extents[j] << ","
<< nest_loop_i_extents[j] << "]";
throw;
}
}
}
}
}
// Step 2. Create merged loops and replace the original loops
Block scope_root = GetRef<Block>(scope_root_sref->StmtAs<BlockNode>());
LoopReconstructor reconstructor(scope_root, lca_nest_loops);
reconstructor.MakeNewLoop();
Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
// Step 3. Do the actual replacement
self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
return self->stmt2ref.at(reconstructor.new_inner_loop_.get());
}

StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) {
// Invariance
// - The total repeat number has not changed for each direct child block.
Expand Down Expand Up @@ -795,6 +954,38 @@ struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};

struct MergeTraits : public UnpackedInstTraits<MergeTraits> {
static constexpr const char* kName = "Merge";
static constexpr bool kIsPure = false;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 0;
static constexpr size_t kNumDecisions = 0;

template <size_t delta>
static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter,
const Array<ObjectRef>& inputs) {
setter(delta, inputs);
}

static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) {
return sch->Merge(loop_rvs);
}

static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) {
PythonAPICall py("merge");
for (const String& loop_rv : loop_rvs) {
py.Input("", loop_rv);
}
py.SingleOutput(outputs);
return py.Str();
}

template <typename>
friend struct ::tvm::tir::UnpackedInstTraits;
};

struct FuseTraits : public UnpackedInstTraits<FuseTraits> {
static constexpr const char* kName = "Fuse";
static constexpr bool kIsPure = false;
Expand Down Expand Up @@ -893,6 +1084,7 @@ struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
};

TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
TVM_REGISTER_INST_KIND_TRAITS(MergeTraits);
TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
.set_body_method<Schedule>(&ScheduleNode::GetConsumers);
/******** (FFI) Transform loops ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method<Schedule>(&ScheduleNode::Merge);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,16 @@ Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {

/******** Schedule: Transform loops ********/

LoopRV TracedScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
LoopRV result = ConcreteScheduleNode::Merge(loop_rvs);
static const InstructionKind& kind = InstructionKind::Get("Merge");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{loop_rvs.begin(), loop_rvs.end()},
/*attrs=*/{},
/*outputs=*/{result}));
return result;
}

LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_loops) {
LoopRV result = ConcreteScheduleNode::Fuse(loop_rvs, preserve_unit_loops);

Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
LoopRV Merge(const Array<LoopRV>& loop_rvs) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs,
bool preserve_unit_iters) final;
void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
Expand Down
Loading

0 comments on commit 2c1bb3d

Please sign in to comment.