diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 42e0e00995fe..82f4afa7a24c 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -393,6 +393,47 @@ std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters()
 bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
                           CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs);
 
+/******** Misc ********/
+
+/*!
+ * \brief Checks if a block could be successfully computed inline into its consumer
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \return A boolean indicating whether the block could be successfully computed inline
+ */
+bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref);
+
+/*!
+ * \brief Checks if a block could be successfully computed inline into its producer
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \return A boolean indicating whether the block could be successfully computed inline
+ */
+bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref);
+
+/*!
+ * \brief Checks if a producer block could be successfully computed at the specific loop.
+ * \param self The schedule state
+ * \param block_sref The block to be moved
+ * \param loop_sref The loop that the block is to be moved under
+ * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \return A boolean indicating whether the block could be successfully computed at the specific
+ * loop
+ */
+bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
+                  bool preserve_unit_loops);
+
+/*!
+ * \brief Checks if a consumer block could be successfully reverse-computed at the specific loop.
+ * \param self The schedule state
+ * \param block_sref The block to be moved
+ * \param loop_sref The loop that the block is to be moved under
+ * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \return A boolean indicating whether the block could be successfully reverse-computed at the
+ * specific loop
+ */
+bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
+                         const StmtSRef& loop_sref, bool preserve_unit_loops);
+
 }  // namespace tir
 }  // namespace tvm
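
The four declarations above are non-throwing counterparts of the corresponding schedule primitives: each one runs the same legality analysis as the primitive itself and reports the outcome as a boolean instead of raising a ScheduleError. A minimal sketch of the intended call pattern (the guard itself is illustrative; the functions are the ones declared above):

    // Probe first, then apply: a caller such as a search strategy can skip
    // illegal candidates without paying for a thrown-and-caught ScheduleError
    // at every call site.
    if (CanComputeAt(self, block_sref, loop_sref, /*preserve_unit_loops=*/true)) {
      ComputeAt(self, block_sref, loop_sref, /*preserve_unit_loops=*/true);
    }
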
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc
index 0dae50abc05e..00886e8f8a22 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -451,7 +451,8 @@ void CalculateProvidedRequiredRegions(
 
 template <bool is_compute_at>
 void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
-                                     const StmtSRef& loop_sref, bool preserve_unit_loops) {
+                                     const StmtSRef& loop_sref, bool preserve_unit_loops,
+                                     arith::Analyzer* analyzer, bool check_only = false) {
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
   // Step 1. Bunch of checks
@@ -463,11 +464,10 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
   BlockScope scope = self->GetBlockScope(scope_root_sref);
   Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
   Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
-  arith::Analyzer analyzer;
   // Check condition 3): `block` and `loop` are under the same scope,
   // and `loop` is not the ancestor of `block`
   NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
-                                              &analyzer);
+                                              analyzer);
   // Check condition 4): `block` is not an output block
   if (is_compute_at) {
     CheckNotOutputBlock(self, block_sref, scope_root_sref);
@@ -501,29 +501,61 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
       CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars,
                               /*provided_regions=*/std::move(provided_regions),
                               /*required_regions=*/std::move(required_regions),
-                              /*analyzer=*/&analyzer);
+                              /*analyzer=*/analyzer);
   // Step 6. Create the new scope according to the iteration domain
   reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms),
                             /*preserve_unit_loops=*/preserve_unit_loops);
   Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
+  // Step 7. Do the actual replacement
+  if (check_only) {
+    return;
+  }
   self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
   // Step 8. Update the cached flags
   BlockInfo& block_info = self->block_info[block_sref];
   block_info.affine_binding = IsAffineBinding(
       /*realize=*/reconstructor.new_block_realize_,
       /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent)),
-      /*analyzer=*/&analyzer);
+      /*analyzer=*/analyzer);
 }
 
 void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
                bool preserve_unit_loops) {
-  ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops);
+  arith::Analyzer analyzer;
+  ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
+                                        &analyzer);
 }
 
 void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
                       bool preserve_unit_loops) {
-  ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops);
+  arith::Analyzer analyzer;
+  ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
+                                         &analyzer);
+}
+
+bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
+                  bool preserve_unit_loops) {
+  arith::Analyzer analyzer;
+  try {
+    ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
+                                          &analyzer, true);
+  } catch (const tvm::runtime::Error& e) {
+    return false;
+  }
+  return true;
+}
+
+bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
+                         const StmtSRef& loop_sref, bool preserve_unit_loops) {
+  arith::Analyzer analyzer;
+  try {
+    ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
+                                           &analyzer, true);
+  } catch (const tvm::runtime::Error& e) {
+    return false;
+  }
+  return true;
 }
 
 /******** InstructionKind Registration ********/
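
All four Can* entry points are built on one mechanism, introduced above as the check_only flag: every check and analysis step of the primitive runs exactly as in the mutating path, and may throw, while the final self->Replace(...) and the cached-flag update are skipped. The arith::Analyzer is now owned by the public wrappers rather than created inside the Impl function, so each call supplies its own analyzer. The distilled shape of the pattern, with hypothetical placeholder names (RunChecksAndAnalysis, MutateIRAndSRefTree) standing in for Steps 1-6:

    // Sketch only: the two helpers below are placeholders, not functions in this patch.
    void PrimitiveImpl(ScheduleState self, const StmtSRef& block_sref,
                       arith::Analyzer* analyzer, bool check_only = false) {
      RunChecksAndAnalysis(self, block_sref, analyzer);  // throws ScheduleError on illegal input
      if (check_only) {
        return;  // dry run: IR and sref tree stay untouched
      }
      MutateIRAndSRefTree(self, block_sref);  // the real replacement
    }

    bool CanPrimitive(const ScheduleState& self, const StmtSRef& block_sref) {
      arith::Analyzer analyzer;
      try {
        PrimitiveImpl(self, block_sref, &analyzer, /*check_only=*/true);
      } catch (const tvm::runtime::Error& e) {
        return false;
      }
      return true;
    }
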
diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index 12ae021a88ee..fe2c679142b7 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -60,11 +60,27 @@ class NotSingleReadWriteBuffer : public ScheduleError {
   bool is_read_;
   Block block_;
 
-  static Buffer GetSingleRead(const ScheduleState& self, const Block& block) {
-    if (block->reads.size() != 1) {
+  static Buffer GetSingleRead(const ScheduleState& self, const Block& block,
+                              const StmtSRef& scope_root_sref) {
+    const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>&
+        buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers;
+    const BufferNode* read_buffer = nullptr;
+    for (const BufferRegion& read_region : block->reads) {
+      const BufferNode* buffer = read_region->buffer.get();
+      if (buffer == read_buffer) {
+        continue;
+      }
+      if (buffer_writers.count(GetRef<Buffer>(buffer)) > 0) {
+        if (read_buffer != nullptr) {
+          throw NotSingleReadWriteBuffer(self->mod, true, block);
+        }
+        read_buffer = buffer;
+      }
+    }
+    if (read_buffer == nullptr) {
       throw NotSingleReadWriteBuffer(self->mod, true, block);
     }
-    return block->reads[0]->buffer;
+    return GetRef<Buffer>(read_buffer);
   }
 
   static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) {
@@ -167,7 +183,7 @@ class OpaqueAccessError : public ScheduleError {
  * \brief The base class of the inliner, which handles:
  * 1) Substitute a subtree with the specific block being inlined
  * 2) Update the block signature to reflect the changes of read/write/allocated buffers
- * 3) Maintain a list of index variables and their substition of the buffer being inlined
+ * 3) Maintain a list of index variables and their substitution of the buffer being inlined
  */
 class BaseInliner : public StmtExprMutator {
  protected:
@@ -526,7 +542,8 @@ class ReverseComputeInliner : public BaseInliner {
   PrimExpr producer_rhs_{nullptr};
 };
 
-void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
+void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
+                       bool check_only = false) {
   const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref);
   Block producer_block = GetRef<Block>(_producer_block);
   Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
@@ -535,6 +552,7 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
                                        /*require_stage_pipeline=*/true,
                                        /*require_subtree_compact_dataflow=*/false);
   // Step 2. Check completeness
+  CheckNotOutputBlock(self, producer_block_sref, scope_root_sref);
   CheckCompleteBlock(self, producer_block_sref, scope_root_sref);
   // Step 3. Analyze the block body
   ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref);
@@ -550,17 +568,35 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
     throw OpaqueAccessError(self->mod, scope_root_sref);
   }
   // Step 6. Do the real mutation on the AST and the sref tree in the schedule state
+  if (check_only) {
+    return;
+  }
   self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
 }
 
-void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
+void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
+  ComputeInlineImpl(self, producer_block_sref);
+}
+
+bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) {
+  try {
+    ComputeInlineImpl(self, producer_block_sref, true);
+  } catch (const tvm::runtime::Error& e) {
+    return false;
+  }
+  return true;
+}
+
+void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref,
+                              bool check_only = false) {
   const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref);
   Block consumer_block = GetRef<Block>(_consumer_block);
-  Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block);
   // Step 1. Get the scope block
   StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref,  //
                                           /*require_stage_pipeline=*/true,
                                           /*require_subtree_compact_dataflow=*/false);
+  Buffer inlined_buffer =
+      NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref);
   // Step 2. Check completeness
   CheckCompleteBlock(self, consumer_block_sref, scope_root_sref);
   // Step 3. Check if the consumer has a single complete producer
@@ -579,9 +615,25 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
     throw OpaqueAccessError(self->mod, scope_root_sref);
   }
   // Step 7. Do the real mutation on the AST and the sref tree in the schedule state
+  if (check_only) {
+    return;
+  }
   self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
 }
 
+bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) {
+  try {
+    ReverseComputeInlineImpl(self, block_sref, true);
+  } catch (const tvm::runtime::Error& e) {
+    return false;
+  }
+  return true;
+}
+
+void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
+  ReverseComputeInlineImpl(self, consumer_block_sref);
+}
+
 /******** InstructionKind Registration ********/
 
 struct ComputeInlineTraits : public UnpackedInstTraits<ComputeInlineTraits> {
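
Two behavioral changes ride along with the refactor here. First, ComputeInlineImpl now calls CheckNotOutputBlock, so inlining a block that writes an output buffer of the scope fails up front (covered by the new test below). Second, GetSingleRead no longer requires the consumer to read exactly one buffer; it picks the unique read buffer that has a writer inside the scope, so additional reads of scope inputs are tolerated. The Can* functions themselves are C++-internal and not exposed to Python by this patch; a rough Python-side analogue of the same dry-run idea, using Schedule.copy() (which keeps block handles valid on the copy), would be:

    import tvm
    from tvm import tir

    def can_compute_inline(sch: tir.Schedule, block) -> bool:
        """Probe whether compute_inline would succeed, without touching sch."""
        try:
            sch.copy().compute_inline(block)  # dry run against a throwaway copy
        except tvm.tir.ScheduleError:
            return False
        return True
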
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py
index a078c0ed4c23..5cc36c0df878 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -329,6 +329,28 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
         B[vi] = A_cache[vi] * 2.0 + 1.0
 
 
+@T.prim_func
+def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None:
+    A = T.match_buffer(var_A, [512, 512], dtype="float32")
+    B = T.match_buffer(var_B, [512, 512], dtype="float32")
+    compute = T.match_buffer(var_compute, [512, 512], dtype="float32")
+    C = T.alloc_buffer([512, 512], dtype="float32")
+    for i0, i1, i2 in T.grid(512, 512, 512):
+        with T.block("C"):
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            T.reads([C[i, j], A[i, k], B[k, j]])
+            T.writes([C[i, j]])
+            with T.init():
+                C[i, j] = T.float32(0)
+            C[i, j] = C[i, j] + A[i, k] * B[k, j]
+    for i0, i1 in T.grid(512, 512):
+        with T.block("compute"):
+            i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
+            T.reads([C[i0_1, i1_1]])
+            T.writes([compute[i0_1, i1_1]])
+            compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
+
+
 # pylint: enable=no-member,invalid-name,unused-variable
 
 
@@ -458,6 +480,13 @@ def test_buffer_matched():
     sch.compute_inline(block_b)
 
 
+def test_output_block():
+    sch = tir.Schedule(matmul_relu, debug_mask="all")
+    block = sch.get_block("compute")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.compute_inline(block)
+
+
 def test_compute_inline_predicate():
     sch = tir.Schedule(elementwise_predicate, debug_mask="all")
     block_b = sch.get_block("B")
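
In matmul_relu, the "compute" block writes the output buffer of the PrimFunc, so test_output_block pins down the CheckNotOutputBlock behavior added above: compute_inline on that block must raise tvm.tir.ScheduleError rather than produce IR in which the function output is never written. The new test can be run in isolation with, for example, pytest tests/python/unittest/test_tir_schedule_compute_inline.py -k test_output_block.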