[TIR][Schedule] reverse_compute_at (#140)
* [TIR][Schedule] reverse_compute_at

* [TIR][Schedule] reverse_compute_at: fixed

* [TIR][Schedule] reverse_compute_at: fix
spectrometerHBH authored and jinhongyii committed Jul 29, 2021
1 parent 85f7ad5 commit 8511a28
Showing 5 changed files with 209 additions and 25 deletions.
10 changes: 8 additions & 2 deletions include/tvm/tir/schedule.h
@@ -139,13 +139,19 @@ class ScheduleNode : public Object {
Array<StmtSRef> split(const StmtSRef& loop_sref, const PrimExpr& nparts, const PrimExpr& factor);

/*!
* \brief Move the block under the loop and regenerate the loops to cover the producing region.
* \param block_sref The block to be moved
* \param loop_sref The target loop
*/
void compute_at(const StmtSRef& block_sref, const StmtSRef& loop_sref);

/*!
* \brief Move the consumer block under the loop and regenerate the loops to cover the consumed region.
* \param block_sref The block to be moved
* \param loop_sref The target loop
*/
void reverse_compute_at(const StmtSRef& block_sref, const StmtSRef& loop_sref);

/*!
* \brief Make the block inline
* \param block_sref The sref of the block
41 changes: 41 additions & 0 deletions python/tvm/tir/schedule.py
@@ -276,6 +276,47 @@ def compute_at(self, block, loop):
"""
ScheduleComputeAt(self, block, loop)

def reverse_compute_at(self, block, loop):
    """Attach one block under a specific loop and cover the required region.
    Note that only a complete block can do reverse_compute_at.

    Parameters
    ----------
    block: Block
        The consumer block to be moved
    loop: Loop
        The target loop

    Example
    -------
    .. code-block:: python

        for i0_outer, i1_outer, i0_inner, i1_inner in tir.grid(8, 8, 16, 16):
            with tir.block([128, 128], "B") as [vi, vj]:
                tir.bind(vi, ((i0_outer*16) + i0_inner))
                tir.bind(vj, ((i1_outer*16) + i1_inner))
                B[vi, vj] = A[vi, vj] * 2.0
        with tir.block([128, 128], "C") as [vi, vj]:
            C[vi, vj] = B[vi, vj] + 1.0

    After reverse_compute_at(C, i0_inner):

    .. code-block:: python

        for i0_outer, i1_outer, i0_inner in tir.grid(8, 8, 16):
            for i1_inner in range(0, 16):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, ((i0_outer*16) + i0_inner))
                    tir.bind(vj, ((i1_outer*16) + i1_inner))
                    B[vi, vj] = A[vi, vj] * 2.0
            for ax1 in range(0, 16):
                with tir.block([128, 128], "C") as [vi, vj]:
                    tir.bind(vi, ((i0_outer*16) + i0_inner))
                    tir.bind(vj, ((i1_outer*16) + ax1))
                    C[vi, vj] = B[vi, vj] + 1.0
    """
    ScheduleReverseComputeAt(self, block, loop)
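A minimal usage sketch of the new primitive, mirroring test_reverse_compute_at added below in this commit (it assumes the element-wise workload from the docstring, where block "B" computes B = A * 2 and block "C" consumes it as C = B + 1):

    s = tir.create_schedule(func)  # func: the element-wise workload above
    B = s.get_block("B")
    C = s.get_block("C")
    i, j = s.get_axes(B)           # tile the producer's loops by 16
    i1, i2 = s.split(i, 16)
    j1, j2 = s.split(j, 16)
    s.reorder(i1, j1, i2, j2)
    s.reverse_compute_at(C, i2)    # attach the consumer C under loop i2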

def bind(self, loop, thread_ivar):
"""Bind ivar to thread index thread_ivar
Parameters
132 changes: 114 additions & 18 deletions src/tir/schedule/schedule_compute_location.cc
@@ -203,7 +203,7 @@ Loop RegenerateLoops(const StmtSRef& block_sref, const StmtSRef& loop_sref, int
}
}
// Step 3. Insert the new statement into the children of the loop
Array<Stmt> stmts = GetChildren(GetRef<Stmt>(loop), true);
stmts.insert(stmts.begin() + insert_pos, body);
// Step 4. Create a new loop with those statements as new children to substitute loop_sref->stmt
ObjectPtr<LoopNode> n = make_object<LoopNode>(*loop);
@@ -218,12 +218,14 @@ Loop RegenerateLoops(const StmtSRef& block_sref, const StmtSRef& loop_sref, int
* \param lca_loop_sref The LCA loop of the producer and the consumer
* \param consumer_blocks The consumer blocks
* \param relax_vars Additional vars to be relaxed according to the execution scope
* \param gather_read If true, gather the read regions of consumer_blocks; otherwise gather their write regions
* \return The required regions, in the same order as produced_regions
*/
std::vector<Range> GatherRequirements(const Array<TensorRegion>& produced_regions,
const StmtSRef& lca_loop_sref,
const std::vector<StmtSRef>& consumer_blocks,
const std::unordered_map<const VarNode*, Range>& relax_vars,
bool gather_read) {
// For each write domain in produced_regions, initialize an empty IntSet for it
std::vector<std::vector<arith::IntSet>> produced_region_reads;
for (const TensorRegion& region : produced_regions) {
@@ -239,9 +241,13 @@ std::vector<Range> GatherRequirements(const Array<TensorRegion>& produced_regions,
}
// For each consumer's read (or write, per gather_read) region
for (const StmtSRef& block_sref : consumer_blocks) {
std::vector<TensorRegion> relaxed;
if (gather_read) {
RelaxRegion(block_sref, lca_loop_sref, &relaxed, nullptr, relax_vars);
} else {
RelaxRegion(block_sref, lca_loop_sref, nullptr, &relaxed, relax_vars);
}
for (const TensorRegion& region : relaxed) {
const BufferNode* buffer = region->buffer.get();
if (!buffer_indexer.count(buffer)) {
continue;
@@ -284,8 +290,8 @@ StmtSRef GetSubTreeOfParent(const StmtSRef& node) {
return GetRef<StmtSRef>(child);
}

std::unordered_map<const VarNode*, Range> RelaxForExecScope(const StmtSRef& loop_sref,
const StmtSRef& block_sref) {
std::unordered_map<const VarNode*, Range> relax_var;
const auto* block = block_sref->GetStmt<BlockNode>();
const BlockRealize& realize = GetBlockRealize(block_sref);
@@ -331,10 +337,9 @@ void ScheduleNode::compute_at(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
/*!
* Check:
* - check input_block is complete/is a dominant reduction block
* - check dependency: all input_block's RAW successors are under input_loop
* - check all blocks in the same sub tree are complete
* - check block is not an output block
*
* Mutate:
* - generate loops that iterate the whole instance space under
@@ -365,14 +370,14 @@ void ScheduleNode::compute_at(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
// Cond 1. 'block' is complete/reduction block
CHECK(scope.IsComplete(block_sref) || scope.IsReduction(block_sref))
<< "ValueError: 'compute_at' expects 'block' to be a complete or reduction block";
// Cond 2. Check all RAW successors are in the subtree rooted by loop_sref
CHECK(EachEdgePointsToABlock(edges_to_succ, GetChildBlocks(loop_sref), /*raw_edge_only=*/true))
<< "ValueError: 'compute_at' does not apply to a block that some other "
<< "blocks outside the scope depends on";
// Cond 3. The subtree has compact data flow
CHECK(scope.IsCompactDataFlow(GetSubTreeOfParent(block_sref), this))
<< "ValueError: 'compute_at' expects the subtree of 'block' to have compact dataflow";
// Cond 4. Check the block is not an output block
for (const TensorRegion& parent_write : parent_block->writes) {
for (const TensorRegion& write : block->writes) {
CHECK_NE(write->buffer.get(), parent_write->buffer.get())
@@ -406,11 +411,102 @@ void ScheduleNode::compute_at(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
// Generate new LoopNode to substitute loop_sref->stmt
Loop new_loop = RegenerateLoops(
block_sref, loop_sref, insert_pos,
SolveCover(block, GatherRequirements(/*produced_regions=*/block->writes,
/*lca_loop_sref=*/loop_sref,
/*consumer_blocks=*/EdgesToSRefs(edges_to_succ),
/*relax_vars=*/RelaxForExecScope(loop_sref, block_sref),
/*gather_read=*/true)));
// Remove leaf
std::pair<Stmt, Stmt> removed = RemoveLeaf(block_sref, this->root);
std::unordered_map<const StmtNode*, const StmtNode*> replace_map = {
{removed.first.get(), removed.second.get()},
{loop_sref->stmt, new_loop.get()},
};
// Mutate the AST with Replace
StmtSRef lca = LowestCommonAncestor({block_sref, loop_sref}, this->root);
Stmt replaced = StmtReplacer(replace_map)(GetRef<Stmt>(lca->stmt));
if (const auto* replaced_block = replaced.as<BlockNode>()) {
this->Replace(lca, replaced, {{GetRef<Block>(replaced_block), GetRef<Block>(parent_block)}});
} else {
this->Replace(lca, replaced);
}
}

void ScheduleNode::reverse_compute_at(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
/*!
* Check:
* - check input_block is complete/is a dominant reduction block
* - check all input_block's RAW predecessors are under input_loop
* - check all blocks in the same sub tree are complete
* - check all input_block's RAW predecessors are complete/dominant reduction block
*
* Mutate:
* - generate loops that iterate the whole instance space under
* input_loop after all the predecessors
*/
const auto* block = block_sref->GetStmt<BlockNode>();
const auto* loop = loop_sref->GetStmt<LoopNode>();
CHECK(block != nullptr)
<< "TypeError: 'reverse_compute_at' expects 'block' to be a block, but get type: "
<< block_sref->stmt->GetTypeKey();
CHECK(loop != nullptr)
<< "TypeError: 'reverse_compute_at' expects 'loop' to be a loop, but get type: "
<< loop_sref->stmt->GetTypeKey();
const StmtSRef& parent_block_sref = GetParentBlockSRef(block_sref);
const BlockNode* parent_block = parent_block_sref->GetStmt<BlockNode>();
const Scope& scope = scopes.at(parent_block_sref);
Array<DepEdge> edges_to_pred = scope.GetPredecessors(block_sref);
Array<DepEdge> edges_to_succ = scope.GetSuccessors(block_sref);
// Cond 0. `block` and `loop` are in the same scope
CHECK_EQ(parent_block_sref.get(), GetParentBlockSRef(loop_sref).get())
<< "ValueError: 'reverse_compute_at' expects 'block' and 'loop' be in the same block";
// Cond 1. 'block' is complete/reduction block
CHECK(scope.IsComplete(block_sref) || scope.IsReduction(block_sref))
<< "ValueError: 'reverse_compute_at' expects 'block' to be a complete or reduction block";
// Cond 2. Check all RAW predecessors are in the subtree rooted by loop_sref
CHECK(EachEdgePointsToABlock(edges_to_pred, GetChildBlocks(loop_sref), /*raw_edge_only=*/true))
<< "ValueError: 'reverse_compute_at' does not apply to a block that some other "
<< "blocks outside the scope depends on";
// Cond 3. The subtree has compact data flow
CHECK(scope.IsCompactDataFlow(GetSubTreeOfParent(block_sref), this))
<< "ValueError: 'reverse_compute_at' expects the subtree of 'block' to have compact dataflow";
// Cond 4. Check all RAW predecessors are complete/reduction block
for (const auto& edge : edges_to_pred)
CHECK(scope.IsComplete(edge->dst) || scope.IsReduction(edge->dst))
<< "ValueError: 'reverse_compute_at' expects producers of 'block' to be a complete or "
"reduction block";
// Mutation
// Step 1. Find insertion position
int insert_pos;
{
// After all predecessors in the dependency graph
Array<Stmt> loop_body = GetChildren(GetRef<Stmt>(loop));
int n_stmts = loop_body.size();
for (insert_pos = n_stmts; insert_pos > 0; --insert_pos) {
const StmtNode* stmt = loop_body[insert_pos - 1].get();
if (AnyEdgePointsToABlock(edges_to_pred, GetChildBlocks(stmt2ref.at(stmt)))) {
break;
}
}
// Before all successors in the dependency graph.
int before_pos;
for (before_pos = 0; before_pos < n_stmts; before_pos++) {
const StmtNode* stmt = loop_body[before_pos].get();
if (AnyEdgePointsToABlock(edges_to_succ, GetChildBlocks(stmt2ref.at(stmt)))) {
break;
}
}
CHECK(insert_pos <= before_pos) << "ValueError: 'reverse_compute_at' cannot find an insertion "
"point that satisfies dependency";
}
// Generate new LoopNode to substitute loop_sref->stmt
Loop new_loop = RegenerateLoops(
block_sref, loop_sref, insert_pos,
SolveCover(block, GatherRequirements(/*produced_regions=*/block->reads,
/*lca_loop_sref=*/loop_sref,
/*consumer_blocks=*/EdgesToSRefs(edges_to_pred),
/*relax_vars=*/{},
/*gather_read=*/false)));
// Remove leaf
std::pair<Stmt, Stmt> removed = RemoveLeaf(block_sref, this->root);
std::unordered_map<const StmtNode*, const StmtNode*> replace_map = {
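The insertion-point search in reverse_compute_at above is a two-scan over the children of the target loop: scan backward to land just after the last statement containing a RAW predecessor, scan forward to the first statement containing a RAW successor, and require the window between the two positions to be non-empty. A self-contained Python sketch of the same invariant (has_pred and has_succ are hypothetical stand-ins for the AnyEdgePointsToABlock checks):

    def find_insert_window(stmts, has_pred, has_succ):
        # insert_pos: just after the last statement containing a RAW predecessor
        n = len(stmts)
        insert_pos = n
        while insert_pos > 0 and not has_pred(stmts[insert_pos - 1]):
            insert_pos -= 1
        # before_pos: at the first statement containing a RAW successor
        before_pos = 0
        while before_pos < n and not has_succ(stmts[before_pos]):
            before_pos += 1
        # any position in [insert_pos, before_pos] respects the dependencies
        assert insert_pos <= before_pos, "no insertion point satisfies the dependencies"
        return insert_pos

The implementation inserts at insert_pos, i.e. as early as the predecessors allow.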
10 changes: 5 additions & 5 deletions src/tir/schedule/schedule_validate.cc
@@ -134,10 +134,10 @@ bool IsAllUniqueVars(const std::vector<PrimExpr>& list) {
* If so, it provides two functions, replace and postproc, for replacing this pattern
* and removing them
*/
class FuseSplitDetector : public ExprVisitor {
public:
/*! \brief Constructor */
explicit FuseSplitDetector(std::unordered_map<const VarNode*, PrimExpr>* loop_var_extents)
: loop_var_extents(loop_var_extents) {}

/*! \brief Check if the PrimExpr is in fuse pattern. If so, set replace and postproc for it */
@@ -260,7 +260,7 @@ class FuseSplitDetecter : public ExprVisitor {
class FuseSplitNormalizer : public ExprMutator {
public:
/*! \brief Constructor */
explicit FuseSplitNormalizer(const FuseSplitDetector& detector) : detector_(detector) {}
/*! \brief Destructor. Invoke postproc only if replacement happens at least once. */
~FuseSplitNormalizer() {
if (replaced_) {
Expand All @@ -280,7 +280,7 @@ class FuseSplitNormalizer : public ExprMutator {

private:
/*! \brief The detector that has detected some pattern */
const FuseSplitDetector& detector_;
/*! \brief Indicating if replacement happens at least once */
bool replaced_ = false;
};
@@ -316,7 +316,7 @@ class LoopValidator : public StmtVisitor {
std::vector<std::pair<PrimExpr, PrimExpr>> predicates = SplitPredicate(realize->predicate);
for (;;) {
// Detect fuse/split pattern
FuseSplitDetector detector(&loop_vars);
for (const auto& binding : bindings) {
detector(binding);
if (detector.replace) {
41 changes: 41 additions & 0 deletions tests/python/tir/test_schedule_primitive.py
@@ -147,6 +147,46 @@ def test_compute_at():
assert s.validate_sref()


@tvm.hybrid.script
def reverse_compute_at_element_wise(a: ty.handle, c: ty.handle) -> None:
    # function attr dict
    C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
    A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
    B = tir.buffer_allocate([128, 128], elem_offset=0, align=128, offset_factor=1)

    # body
    for i0_outer in range(0, 8):
        for i1_outer in range(0, 8):
            for i0_inner in range(0, 16):
                for i1_inner in range(0, 16):
                    with tir.block([128, 128], "B") as [vi, vj]:
                        tir.bind(vi, ((i0_outer*16) + i0_inner))
                        tir.bind(vj, ((i1_outer*16) + i1_inner))
                        B[vi, vj] = (A[vi, vj]*tir.float32(2))
                for ax1 in range(0, 16):
                    with tir.block([128, 128], "C") as [vi, vj]:
                        tir.bind(vi, ((i0_outer*16) + i0_inner))
                        tir.bind(vj, ((i1_outer*16) + ax1))
                        C[vi, vj] = (B[vi, vj] + tir.float32(1))


def test_reverse_compute_at():
    func = util.element_wise_stmt()

    # schedule
    s = tir.create_schedule(func)
    B = s.get_block("B")
    C = s.get_block("C")
    i, j = s.get_axes(B)
    i1, i2 = s.split(i, 16)
    j1, j2 = s.split(j, 16)
    s.reorder(i1, j1, i2, j2)
    s.reverse_compute_at(C, i2)

    tvm.ir.assert_structural_equal(reverse_compute_at_element_wise, s.func)
    assert s.validate_sref()
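For contrast, compute_at is the producer-side dual: it moves the producer under the consumer's loop. A mirrored schedule, sketched under the same element-wise workload assumption as above, would tile C's axes and attach B:

    s = tir.create_schedule(func)
    B = s.get_block("B")
    C = s.get_block("C")
    i, j = s.get_axes(C)   # tile the consumer's loops instead
    i1, i2 = s.split(i, 16)
    j1, j2 = s.split(j, 16)
    s.reorder(i1, j1, i2, j2)
    s.compute_at(B, i2)    # attach the producer B under loop i2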


@tvm.hybrid.script
def predicate_fuse(b: ty.handle, c: ty.handle) -> None:
C = tir.match_buffer(c, (16, 16), "float32")
@@ -493,6 +533,7 @@ def test_cache_read_write():
test_fuse_loop_sref()
test_reorder_normal()
test_compute_at()
test_reverse_compute_at()
test_compute_inline()
test_compute_at_fail()
test_reduction()
