Skip to content

Commit

Permalink
[Meta Schedule Refactor] Get child blocks (#500)
Browse files Browse the repository at this point in the history
* get child blocks

* clang format

* black

* test

* test

* test
  • Loading branch information
spectrometerHBH authored and junrushao committed Nov 5, 2021
1 parent babb0d9 commit fa3f277
Show file tree
Hide file tree
Showing 10 changed files with 165 additions and 0 deletions.
12 changes: 12 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,18 @@ class ScheduleNode : public runtime::Object {
* \return A list of loops above the given block in its scope, from outer to inner
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
/*!
* \brief Get the leaf blocks of a specific scope
* \param block_rv The block where the scope is rooted
* \return A list of child blocks
*/
virtual Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) = 0;
/*!
* \brief Get the leaf blocks of under a specific loop
* \param loop_rv The loop under which collecting is conducted
* \return A list of child blocks
*/
virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
/******** Schedule: Transform loops ********/
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,21 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
"""
return _ffi_api.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member

def get_child_blocks(self, block_or_loop: Union[BlockRV, LoopRV]) -> List[BlockRV]:
"""Get the leaf blocks of a specific block/loop
Parameters
----------
block_or_loop : Union[BlockRV, LoopRV]
The query block/loop
Returns
-------
blocks : List[LoopRV]
A list of leaf blocks inside a specific block/loop
"""
return _ffi_api.ScheduleGetChildBlocks(self, block_or_loop) # pylint: disable=no-member

########## Schedule: Transform loops ##########
def fuse(self, *loops: List[LoopRV]) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
Expand Down
18 changes: 18 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,24 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
return CreateRV<LoopRV>(tir::GetLoops(this->GetSRef(block_rv)));
}

Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
Array<BlockRV> result;
TVM_TIR_SCHEDULE_BEGIN();
result = CreateRV<BlockRV>(tir::GetChildBlocks(state_, this->GetSRef(block_rv), false));
TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_);
this->state_->DebugVerify();
return result;
}

Array<BlockRV> ConcreteScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
Array<BlockRV> result;
TVM_TIR_SCHEDULE_BEGIN();
result = CreateRV<BlockRV>(tir::GetChildBlocks(state_, this->GetSRef(loop_rv), false));
TVM_TIR_SCHEDULE_END("get-child-blocks", this->error_render_level_);
this->state_->DebugVerify();
return result;
}

/******** Schedule: Transform loops ********/

LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Get blocks & loops ********/
BlockRV GetBlock(const String& name, const String& func_name = "main") override;
Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override;
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const S
* \return A list of loops above the given block in its scope, from outer to inner
*/
Array<StmtSRef> GetLoops(const StmtSRef& block_sref);
/*!
* \brief Get the leaf blocks of a specific block/loop
* \param self The schedule state
* \param parent_sref The query block/loop
* \param inclusive Whether to include parent_sref
* \return A list of leaf blocks inside a specific block/loop
*/
Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref,
bool inclusive = false);
/******** Schedule: Transform loops ********/
/*!
* Split a loop into a list of consecutive loops. It requires:
Expand Down
56 changes: 56 additions & 0 deletions src/tir/schedule/primitive/get_block_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,31 @@ Array<StmtSRef> GetLoops(const StmtSRef& block_sref) {
return {result.rbegin(), result.rend()};
}

Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref,
bool inclusive) {
struct Collector : public StmtVisitor {
private:
void VisitStmt_(const BlockNode* block) final { result.push_back(self->stmt2ref.at(block)); }

public:
explicit Collector(const ScheduleState& self) : self(self) {}

const ScheduleState& self;
Array<StmtSRef> result;
};
Collector collector(self);
if (inclusive) {
collector(GetRef<Stmt>(parent_sref->stmt));
} else if (parent_sref->stmt->IsInstance<ForNode>()) {
const auto* loop = static_cast<const ForNode*>(parent_sref->stmt);
collector(loop->body);
} else if (parent_sref->stmt->IsInstance<BlockNode>()) {
const auto* block = static_cast<const BlockNode*>(parent_sref->stmt);
collector(block->body);
}
return std::move(collector.result);
}

/******** InstructionKind Registration ********/

struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> {
Expand Down Expand Up @@ -106,8 +131,39 @@ struct GetLoopsTraits : public UnpackedInstTraits<GetLoopsTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};

struct GetChildBlocksTraits : public UnpackedInstTraits<GetChildBlocksTraits> {
static constexpr const char* kName = "GetChildBlocks";
static constexpr bool kIsPure = true;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 0;
static constexpr size_t kNumDecisions = 0;

static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv) {
if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
return sch->GetChildBlocks(GetRef<BlockRV>(block));
}
if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
return sch->GetChildBlocks(GetRef<LoopRV>(loop));
}
LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " << block_or_loop_rv->GetTypeKey();
throw;
}

static String UnpackedAsPython(Array<String> outputs, String block_or_loop_rv) {
PythonAPICall py("get_child_blocks");
py.Input("", block_or_loop_rv);
py.OutputList(outputs);
return py.Str();
}

friend struct UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits);

} // namespace tir
} // namespace tvm
12 changes: 12 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock")
.set_body_method<Schedule>(&ScheduleNode::GetBlock);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops")
.set_body_method<Schedule>(&ScheduleNode::GetLoops);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetChildBlocks")
.set_body_typed([](Schedule self, ObjectRef rv) {
if (const auto* block_rv = rv.as<BlockRVNode>()) {
return self->GetChildBlocks(GetRef<BlockRV>(block_rv));
}
if (const auto* loop_rv = rv.as<LoopRVNode>()) {
return self->GetChildBlocks(GetRef<LoopRV>(loop_rv));
}
LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
<< ". Its value is: " << rv;
throw;
});
/******** (FFI) Transform loops ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
Expand Down
22 changes: 22 additions & 0 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,28 @@ Array<LoopRV> TracedScheduleNode::GetLoops(const BlockRV& block_rv) {
return results;
}

Array<BlockRV> TracedScheduleNode::GetChildBlocks(const BlockRV& block_rv) {
Array<BlockRV> results = ConcreteScheduleNode::GetChildBlocks(block_rv);

static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
/*inputs=*/{block_rv},
/*attrs=*/{},
/*outputs=*/{results.begin(), results.end()}));
return results;
}

Array<BlockRV> TracedScheduleNode::GetChildBlocks(const LoopRV& loop_rv) {
Array<BlockRV> results = ConcreteScheduleNode::GetChildBlocks(loop_rv);

static const InstructionKind& kind = InstructionKind::Get("GetChildBlocks");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind, //
/*inputs=*/{loop_rv},
/*attrs=*/{},
/*outputs=*/{results.begin(), results.end()}));
return results;
}

/******** Schedule: Transform loops ********/

LoopRV TracedScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
/******** Schedule: Get blocks & loops ********/
BlockRV GetBlock(const String& name, const String& func_name = "main") final;
Array<LoopRV> GetLoops(const BlockRV& block_rv) final;
Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) final;
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_tir_schedule_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,5 +142,22 @@ def test_tir_schedule_remove_rv():
sch.get(block_rv)


def test_get_child_blocks():
s = tir.Schedule(matmul, debug_mask="all")
init = s.get_block("init")
update = s.get_block("update")
# loop
blocks = s.get_child_blocks(s.get_loops(init)[0])
assert len(blocks) == 2
assert s.get(init) == s.get(blocks[0])
assert s.get(update) == s.get(blocks[1])
# block
root = s.get_block("root")
blocks = s.get_child_blocks(root)
assert len(blocks) == 2
assert s.get(init) == s.get(blocks[0])
assert s.get(update) == s.get(blocks[1])


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit fa3f277

Please sign in to comment.