[TIR][Schedule] Analysis functions to check if compute_inline and compute_at is allowed (apache#9743)

* [TIR][Schedule] Analysis functions to check if compute_inline and compute_at is allowed

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>

* Address comments

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Xiyou Zhou <xiyou@octoml.ai>
7 people authored and ylc committed Jan 13, 2022
1 parent 7fd300f commit 3115e85
Showing 4 changed files with 168 additions and 14 deletions.
41 changes: 41 additions & 0 deletions src/tir/schedule/analysis.h
@@ -393,6 +393,47 @@ std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters()
bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs);

/******** Misc ********/

/*!
* \brief Checks if a block could be successfully computed inline into its consumer
* \param self The schedule state
* \param block_sref The block to be checked
* \return A boolean indicating whether the block could be successfully computed inline
*/
bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref);

/*!
* \brief Checks if a block could be successfully reverse computed inline into its producer
* \param self The schedule state
* \param block_sref The block to be checked
* \return A boolean indicating whether the block could be successfully computed inline
*/
bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref);

/*!
* \brief Checks if a producer block could be successfully computed at the specific loop.
* \param self The schedule state
* \param block_sref The block to be moved
* \param loop_sref The loop to which the block is to be moved
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \return A boolean indicating whether the block could be successfully computed at the specific loop
*/
bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops);

/*!
* \brief Checks if a consumer block could be successfully computed at the specific loop.
* \param self The schedule state
* \param block_sref The block to be moved
* \param loop_sref The loop to which the block is to be moved
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \return A boolean indicating whether the block could be successfully reverse computed at the
* specific loop
*/
bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops);

} // namespace tir
} // namespace tvm

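Note: the four predicates declared above are dry-run counterparts of the mutating primitives (ComputeInline, ReverseComputeInline, ComputeAt, ReverseComputeAt). A minimal call-site sketch, assuming self, block_sref, and loop_sref are already in scope (illustrative only, not part of this commit):

// Probe first; mutate only if the dry run succeeds.
if (CanComputeAt(self, block_sref, loop_sref, /*preserve_unit_loops=*/true)) {
  ComputeAt(self, block_sref, loop_sref, /*preserve_unit_loops=*/true);
}
if (CanComputeInline(self, block_sref)) {
  ComputeInline(self, block_sref);
}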
46 changes: 39 additions & 7 deletions src/tir/schedule/primitive/compute_at.cc
@@ -451,7 +451,8 @@ void CalculateProvidedRequiredRegions(

template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops) {
const StmtSRef& loop_sref, bool preserve_unit_loops,
arith::Analyzer* analyzer, bool check_only = false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
@@ -463,11 +464,10 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
arith::Analyzer analyzer;
// Check condition 3): `block` and `loop` are under the same scope,
// and `loop` is not the ancestor of `block`
NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
&analyzer);
analyzer);
// Check condition 4): `block` is not an output block
if (is_compute_at) {
CheckNotOutputBlock(self, block_sref, scope_root_sref);
@@ -501,29 +501,61 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars,
/*provided_regions=*/std::move(provided_regions),
/*required_regions=*/std::move(required_regions),
/*analyzer=*/&analyzer);
/*analyzer=*/analyzer);
// Step 6. Create the new scope according to the iteration domain
reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms),
/*preserve_unit_loops=*/preserve_unit_loops);
Block new_scope_root = Downcast<Block>(reconstructor(scope_root));

// Step 7. Do the actual replacement
if (check_only) {
return;
}
self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
// Step 8. Update the cached flags
BlockInfo& block_info = self->block_info[block_sref];
block_info.affine_binding = IsAffineBinding(
/*realize=*/reconstructor.new_block_realize_,
/*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent)),
/*analyzer=*/&analyzer);
/*analyzer=*/analyzer);
}

void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops);
arith::Analyzer analyzer;
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer);
}

void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops);
arith::Analyzer analyzer;
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer);
}

bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
arith::Analyzer analyzer;
try {
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer, true);
} catch (const tvm::runtime::Error& e) {
return false;
}
return true;
}

bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops) {
arith::Analyzer analyzer;
try {
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer, true);
} catch (const tvm::runtime::Error& e) {
return false;
}
return true;
}

/******** InstructionKind Registration ********/
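CanComputeAt and CanReverseComputeAt above (and the inline predicates in the next file) share one dry-run pattern: the shared Impl function runs every validation step, each of which throws a ScheduleError on failure, and check_only returns just before the single mutating call, self->Replace; the catch of tvm::runtime::Error works because ScheduleError is a subclass. A distilled sketch of the pattern, with FooImpl/CanFoo as hypothetical placeholder names rather than functions in this commit:

// All checks throw on failure; check_only stops short of mutation.
void FooImpl(ScheduleState self, bool check_only = false) {
  // ... validation steps, each throwing a ScheduleError on failure ...
  if (check_only) {
    return;  // dry run succeeded: no AST or sref-tree mutation
  }
  // self->Replace(...);  // the single mutating step
}

bool CanFoo(const ScheduleState& self) {
  try {
    FooImpl(self, /*check_only=*/true);
  } catch (const tvm::runtime::Error& e) {
    return false;  // some check threw: the transformation would fail
  }
  return true;
}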
66 changes: 59 additions & 7 deletions src/tir/schedule/primitive/compute_inline.cc
@@ -60,11 +60,27 @@ class NotSingleReadWriteBuffer : public ScheduleError {
bool is_read_;
Block block_;

static Buffer GetSingleRead(const ScheduleState& self, const Block& block) {
if (block->reads.size() != 1) {
static Buffer GetSingleRead(const ScheduleState& self, const Block& block,
const StmtSRef& scope_root_sref) {
const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>&
buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers;
const BufferNode* read_buffer = nullptr;
for (const BufferRegion& read_region : block->reads) {
const BufferNode* buffer = read_region->buffer.get();
if (buffer == read_buffer) {
continue;
}
if (buffer_writers.count(GetRef<Buffer>(buffer)) > 0) {
if (read_buffer != nullptr) {
throw NotSingleReadWriteBuffer(self->mod, true, block);
}
read_buffer = buffer;
}
}
if (read_buffer == nullptr) {
throw NotSingleReadWriteBuffer(self->mod, true, block);
}
return block->reads[0]->buffer;
return GetRef<Buffer>(read_buffer);
}

static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) {
Expand Down Expand Up @@ -167,7 +183,7 @@ class OpaqueAccessError : public ScheduleError {
* \brief The base class of the inliner, which handles:
* 1) Substitute a subtree with the specific block being inlined
* 2) Update the block signature to reflect the changes of read/write/allocated buffers
* 3) Maintain a list of index variables and their substition of the buffer being inlined
* 3) Maintain a list of index variables and their substitution of the buffer being inlined
*/
class BaseInliner : public StmtExprMutator {
protected:
@@ -526,7 +542,8 @@ class ReverseComputeInliner : public BaseInliner {
PrimExpr producer_rhs_{nullptr};
};

void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
bool check_only = false) {
const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref);
Block producer_block = GetRef<Block>(_producer_block);
Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
@@ -535,6 +552,7 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
// Step 2. Check completeness
CheckNotOutputBlock(self, producer_block_sref, scope_root_sref);
CheckCompleteBlock(self, producer_block_sref, scope_root_sref);
// Step 3. Analyze the block body
ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref);
@@ -550,17 +568,35 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
throw OpaqueAccessError(self->mod, scope_root_sref);
}
// Step 6. Do the real mutation on the AST and the sref tree in the schedule state
if (check_only) {
return;
}
self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
}

void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
ComputeInlineImpl(self, producer_block_sref);
}

bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) {
try {
ComputeInlineImpl(self, producer_block_sref, true);
} catch (const tvm::runtime::Error& e) {
return false;
}
return true;
}

void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref,
bool check_only = false) {
const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref);
Block consumer_block = GetRef<Block>(_consumer_block);
Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, //
/*require_stage_pipeline=*/true,
/*require_subtree_compact_dataflow=*/false);
Buffer inlined_buffer =
NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref);
// Step 2. Check completeness
CheckCompleteBlock(self, consumer_block_sref, scope_root_sref);
// Step 3. Check if the consumer has a single complete producer
@@ -579,9 +615,25 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
throw OpaqueAccessError(self->mod, scope_root_sref);
}
// Step 7. Do the real mutation on the AST and the sref tree in the schedule state
if (check_only) {
return;
}
self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
}

bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) {
try {
ReverseComputeInlineImpl(self, block_sref, true);
} catch (const tvm::runtime::Error& e) {
return false;
}
return true;
}

void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
ReverseComputeInlineImpl(self, consumer_block_sref);
}

/******** InstructionKind Registration ********/

struct ComputeInlineTraits : public UnpackedInstTraits<ComputeInlineTraits> {
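Because the new predicates never throw, a scheduler can filter candidate blocks without wrapping every call in try/catch. A hypothetical sketch, assuming a candidate_blocks list gathered elsewhere (FilterInlinable is illustrative and not part of this commit):

#include <vector>

std::vector<StmtSRef> FilterInlinable(const ScheduleState& self,
                                      const std::vector<StmtSRef>& candidate_blocks) {
  std::vector<StmtSRef> inlinable;
  for (const StmtSRef& block_sref : candidate_blocks) {
    if (CanComputeInline(self, block_sref)) {  // false for, e.g., output blocks
      inlinable.push_back(block_sref);
    }
  }
  return inlinable;
}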
29 changes: 29 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -329,6 +329,28 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
B[vi] = A_cache[vi] * 2.0 + 1.0


@T.prim_func
def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None:
A = T.match_buffer(var_A, [512, 512], dtype="float32")
B = T.match_buffer(var_B, [512, 512], dtype="float32")
compute = T.match_buffer(var_compute, [512, 512], dtype="float32")
C = T.alloc_buffer([512, 512], dtype="float32")
for i0, i1, i2 in T.grid(512, 512, 512):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i0, i1, i2])
T.reads([C[i, j], A[i, k], B[k, j]])
T.writes([C[i, j]])
with T.init():
C[i, j] = T.float32(0)
C[i, j] = C[i, j] + A[i, k] * B[k, j]
for i0, i1 in T.grid(512, 512):
with T.block("compute"):
i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
T.reads([C[i0_1, i1_1]])
T.writes([compute[i0_1, i1_1]])
compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))


# pylint: enable=no-member,invalid-name,unused-variable


@@ -458,6 +480,13 @@ def test_buffer_matched():
sch.compute_inline(block_b)


def test_output_block():
sch = tir.Schedule(matmul_relu, debug_mask="all")
block = sch.get_block("compute")
with pytest.raises(tvm.tir.ScheduleError):
sch.compute_inline(block)


def test_compute_inline_predicate():
sch = tir.Schedule(elementwise_predicate, debug_mask="all")
block_b = sch.get_block("B")
