From 7ce474ae0a0fb56fb3b5169b2cc12c7f16068c78 Mon Sep 17 00:00:00 2001
From: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Date: Tue, 6 Oct 2020 21:05:09 +0800
Subject: [PATCH] [TIR][Schedule] reverse_compute_at (#140)

* [TIR][Schedule] reverse_compute_at

* [TIR][Schedule] reverse_compute_at: fixed

* [TIR][Schedule] reverse_compute_at: fix
---
 include/tvm/tir/schedule.h                    |  10 +-
 python/tvm/tir/schedule.py                    |  41 ++++++
 src/tir/schedule/schedule.cc                  |   6 +
 src/tir/schedule/schedule_compute_location.cc | 132 +++++++++++++++---
 src/tir/schedule/schedule_validate.cc         |  10 +-
 tests/python/tir/test_schedule_primitive.py   |  41 ++++++
 6 files changed, 215 insertions(+), 25 deletions(-)

diff --git a/include/tvm/tir/schedule.h b/include/tvm/tir/schedule.h
index b2675be64e..624cdb32ae 100644
--- a/include/tvm/tir/schedule.h
+++ b/include/tvm/tir/schedule.h
@@ -139,13 +139,19 @@ class ScheduleNode : public Object {
   Array<StmtSRef> split(const StmtSRef& loop_sref, const PrimExpr& nparts, const PrimExpr& factor);
 
   /*!
-   * \brief Move the block under the loop and regenerate the
-   * loops to cover the producing region.
+   * \brief Move the block under the loop and regenerate the loops to cover the producing region.
    * \param block_sref The block to be moved
    * \param loop_sref The target loop
    */
   void compute_at(const StmtSRef& block_sref, const StmtSRef& loop_sref);
 
+  /*!
+   * \brief Move the block under the loop and regenerate the loops to cover the consumed region.
+   * \param block_sref The block to be moved
+   * \param loop_sref The target loop
+   */
+  void reverse_compute_at(const StmtSRef& block_sref, const StmtSRef& loop_sref);
+
   /*!
    * \brief Make the block inline
    * \param block_sref The sref of the block
diff --git a/python/tvm/tir/schedule.py b/python/tvm/tir/schedule.py
index 670446f48b..7fcd98ccda 100644
--- a/python/tvm/tir/schedule.py
+++ b/python/tvm/tir/schedule.py
@@ -276,6 +276,47 @@ def compute_at(self, block, loop):
         """
         ScheduleComputeAt(self, block, loop)
 
+    def reverse_compute_at(self, block, loop):
+        """Attach one block under a specific loop and cover the required region.
+        Note that only a complete block can do reverse_compute_at.
+
+        Parameters
+        ----------
+        block: Block
+            The block to be moved
+
+        loop: Loop
+            The target loop
+
+        Example
+        -------
+        .. code-block:: python
+
+            for i0_outer, i1_outer, i0_inner, i1_inner in tir.grid(8, 8, 16, 16):
+                with tir.block([128, 128], "B") as [vi, vj]:
+                    tir.bind(vi, ((i0_outer*16) + i0_inner))
+                    tir.bind(vj, ((i1_outer*16) + i1_inner))
+                    B[vi, vj] = A[vi, vj] * 2.0
+            with tir.block([128, 128], "C") as [vi, vj]:
+                C[vi, vj] = B[vi, vj] + 1.0
+
+        After reverse_compute_at(C, i0_inner)
+
+        .. code-block:: python
+
+            for i0_outer, i1_outer, i0_inner in tir.grid(8, 8, 16):
+                for i1_inner in range(0, 16):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, ((i0_outer*16) + i0_inner))
+                        tir.bind(vj, ((i1_outer*16) + i1_inner))
+                        B[vi, vj] = A[vi, vj] * 2.0
+                for ax1 in range(0, 16):
+                    with tir.block([128, 128], "C") as [vi, vj]:
+                        tir.bind(vi, ((i0_outer*16) + i0_inner))
+                        tir.bind(vj, ((i1_outer*16) + ax1))
+                        C[vi, vj] = B[vi, vj] + 1.0
+        """
+        ScheduleReverseComputeAt(self, block, loop)
+
     def bind(self, loop, thread_ivar):
         """Bind ivar to thread index thread_ivar
 
         Parameters
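The new Python entry point mirrors compute_at with the roles flipped: compute_at moves a
producer block under a loop of its consumer, while reverse_compute_at moves a consumer block
under a loop of its producer. A minimal sketch of the contrast, assuming the element-wise
workload and helpers the test suite at the end of this patch uses (util.element_wise_stmt,
get_block, get_axes are taken from there; treat the exact calls as assumptions):

.. code-block:: python

    from tvm import tir

    # compute_at moves the *producer* B under a loop of its consumer C ...
    s = tir.create_schedule(util.element_wise_stmt())
    B, C = s.get_block("B"), s.get_block("C")
    ci, cj = s.get_axes(C)
    s.compute_at(B, ci)

    # ... while reverse_compute_at moves the *consumer* C under a loop of its
    # producer B (on a fresh schedule, since the two transformations overlap).
    s = tir.create_schedule(util.element_wise_stmt())
    B, C = s.get_block("B"), s.get_block("C")
    bi, bj = s.get_axes(B)
    s.reverse_compute_at(C, bi)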
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 00d5d95d6a..65427051c8 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -560,6 +560,12 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeAt")
       return schedule->compute_at(block_sref, loop_sref);
     });
 
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeAt")
+    .set_body_typed([](Schedule schedule, StmtSRef block_sref,
+                       StmtSRef loop_sref) {
+      return schedule->reverse_compute_at(block_sref, loop_sref);
+    });
+
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline")
     .set_body_typed([](Schedule schedule, StmtSRef block_sref) {
       return schedule->compute_inline(block_sref);
diff --git a/src/tir/schedule/schedule_compute_location.cc b/src/tir/schedule/schedule_compute_location.cc
index 5de5dc764f..778df8d0b2 100644
--- a/src/tir/schedule/schedule_compute_location.cc
+++ b/src/tir/schedule/schedule_compute_location.cc
@@ -203,7 +203,7 @@ Loop RegenerateLoops(const StmtSRef& block_sref, const StmtSRef& loop_sref, int
     }
   }
   // Step 3. Insert the new statement into the children of the loop
-  Array<Stmt> stmts = GetChildren(GetRef<Stmt>(loop));
+  Array<Stmt> stmts = GetChildren(GetRef<Stmt>(loop), true);
   stmts.insert(stmts.begin() + insert_pos, body);
   // Step 4. Create a new loop with those statements as new children to substitute loop_sref->stmt
   ObjectPtr<LoopNode> n = make_object<LoopNode>(*loop);
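RegenerateLoops materializes the loops that SolveCover derives, so the regenerated nest
covers exactly the instances the dependent blocks require. A plain-Python spot check of that
cover for the docstring example above (illustrative arithmetic only, not the TVM API):

.. code-block:: python

    # For a fixed (i0_outer, i1_outer, i0_inner), the i1_inner loop produces a
    # 16-wide row tile of B, and the regenerated ax1 loop makes C consume
    # exactly that tile, so the RAW dependency is preserved.
    i0_outer, i1_outer, i0_inner = 3, 5, 7  # any fixed outer instance
    vi = i0_outer * 16 + i0_inner
    produced = {(vi, i1_outer * 16 + i1_inner) for i1_inner in range(16)}
    consumed = {(vi, i1_outer * 16 + ax1) for ax1 in range(16)}
    assert consumed <= produced  # the regenerated loop covers the requirement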
@@ -218,12 +218,14 @@
  * \param lca_loop_sref The lca of producer and consumer
  * \param consumer_blocks The consumer blocks
  * \param relax_vars The additional vars that should be relaxed according to the execution scope
+ * \param gather_read If true, gather the read regions of consumer_blocks; otherwise, gather their write regions
  * \return The required regions, in the same order as produced_regions
  */
 std::vector<TensorRegion> GatherRequirements(const Array<TensorRegion>& produced_regions,
                                              const StmtSRef& lca_loop_sref,
                                              const std::vector<StmtSRef>& consumer_blocks,
-                                             const std::unordered_map<const VarNode*, Range>& relax_vars) {
+                                             const std::unordered_map<const VarNode*, Range>& relax_vars,
+                                             bool gather_read) {
   // For each write domain in produced_regions, initiate an empty IntSet for it
   std::vector<std::vector<arith::IntSet>> produced_region_reads;
   for (const TensorRegion& region : produced_regions) {
@@ -239,9 +241,13 @@ std::vector<TensorRegion> GatherRequirements(const Array<TensorRegion>& produced
   }
   // For each consumer's reading region
   for (const StmtSRef& block_sref : consumer_blocks) {
-    std::vector<TensorRegion> reads;
-    RelaxRegion(block_sref, lca_loop_sref, &reads, nullptr, relax_vars);
-    for (const TensorRegion& region : reads) {
+    std::vector<TensorRegion> relaxed;
+    if (gather_read) {
+      RelaxRegion(block_sref, lca_loop_sref, &relaxed, nullptr, relax_vars);
+    } else {
+      RelaxRegion(block_sref, lca_loop_sref, nullptr, &relaxed, relax_vars);
+    }
+    for (const TensorRegion& region : relaxed) {
       const BufferNode* buffer = region->buffer.get();
       if (!buffer_indexer.count(buffer)) {
         continue;
@@ -284,8 +290,8 @@ StmtSRef GetSubTreeOfParent(const StmtSRef& node) {
   return GetRef<StmtSRef>(child);
 }
 
-std::unordered_map<const VarNode*, Range> RelaxForExeScope(const StmtSRef& loop_sref,
-                                                           const StmtSRef& block_sref) {
+std::unordered_map<const VarNode*, Range> RelaxForExecScope(const StmtSRef& loop_sref,
+                                                            const StmtSRef& block_sref) {
   std::unordered_map<const VarNode*, Range> relax_var;
   const auto* block = block_sref->GetStmt<BlockNode>();
   const BlockRealize& realize = GetBlockRealize(block_sref);
@@ -331,10 +337,9 @@ void ScheduleNode::compute_at(const StmtSRef& block_sref, const StmtSRef& loop_s
 /*!
  * Check:
  * - check input_block is complete/is a dominant reduction block
- * - check input_block's RAW predecessors are complete
  * - check dependency: all input_block's RAW successors are under input_loop
- * - check one-way fine-grained data flow: all blocks in the same sub tree are complete
- * - check block is not a output block
+ * - check all blocks in the same sub tree are complete
+ * - check block is not an output block
  *
  * Mutate:
  * - generate loops that iterate the whole instance space under
@@ -365,14 +370,14 @@ void ScheduleNode::compute_at(const StmtSRef& block_sref, const StmtSRef& loop_s
   // Cond 1. 'block' is complete/reduction block
   CHECK(scope.IsComplete(block_sref) || scope.IsReduction(block_sref))
       << "ValueError: 'compute_at' expects 'block' to be a complete or reduction block";
-  // Cond 3. Check all RAW successors are in the subtree rooted by loop_sref
+  // Cond 2. Check all RAW successors are in the subtree rooted by loop_sref
   CHECK(EachEdgePointsToABlock(edges_to_succ, GetChildBlocks(loop_sref), /*raw_edge_only=*/true))
       << "ValueError: 'compute_at' does not apply to a block that some other "
       << "blocks outside the scope depend on";
-  // Cond 4. The subtree has compact data flow
+  // Cond 3. The subtree has compact data flow
   CHECK(scope.IsCompactDataFlow(GetSubTreeOfParent(block_sref), this))
       << "ValueError: 'compute_at' expects the subtree of 'block' to have compact dataflow";
-  // Cond 5. Check the block is not a output block
+  // Cond 4. Check the block is not an output block
   for (const TensorRegion& parent_write : parent_block->writes) {
     for (const TensorRegion& write : block->writes) {
       CHECK_NE(write->buffer.get(), parent_write->buffer.get())
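Both primitives funnel into SolveCover(GatherRequirements(...)): for each region the moved
block produces (or consumes, for reverse_compute_at), GatherRequirements unions the relaxed
regions that the dependent blocks require of the same buffer, and SolveCover then solves for
loop extents covering that union. A plain-Python sketch of the union step, with hypothetical
helpers and 1-D intervals standing in for IntSets:

.. code-block:: python

    # A "region" here is a dict: buffer name -> (lo, hi) half-open interval.
    def gather_requirements(produced, consumer_regions):
        required = {buf: None for buf in produced}  # one slot per produced buffer
        for region in consumer_regions:  # each consumer's relaxed region
            for buf, (lo, hi) in region.items():
                if buf not in required:
                    continue  # the consumer also touches unrelated buffers
                cur = required[buf]
                required[buf] = (lo, hi) if cur is None else (min(cur[0], lo), max(cur[1], hi))
        return required

    # Two consumers reading different tiles of "B" force the union [0, 32).
    print(gather_requirements({"B"}, [{"B": (0, 16)}, {"B": (16, 32), "A": (0, 8)}]))
    # -> {'B': (0, 32)}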
@@ -406,11 +411,102 @@ void ScheduleNode::compute_at(const StmtSRef& block_sref, const StmtSRef& loop_s
   // Generate new LoopNode to substitute loop_sref->stmt
   Loop new_loop = RegenerateLoops(
       block_sref, loop_sref, insert_pos,
-      SolveCover(block,
-                 GatherRequirements(/*produced_regions=*/block->writes,
-                                    /*lca_loop_sref=*/loop_sref,
-                                    /*consumer_blocks=*/EdgesToSRefs(edges_to_succ),
-                                    /*relax_vars=*/RelaxForExeScope(loop_sref, block_sref))));
+      SolveCover(block, GatherRequirements(/*produced_regions=*/block->writes,
+                                           /*lca_loop_sref=*/loop_sref,
+                                           /*consumer_blocks=*/EdgesToSRefs(edges_to_succ),
+                                           /*relax_vars=*/RelaxForExecScope(loop_sref, block_sref),
+                                           /*gather_read=*/true)));
+  // Remove leaf
+  std::pair<Stmt, Stmt> removed = RemoveLeaf(block_sref, this->root);
+  std::unordered_map<const StmtNode*, const StmtNode*> replace_map = {
+      {removed.first.get(), removed.second.get()},
+      {loop_sref->stmt, new_loop.get()},
+  };
+  // Mutate the AST with Replace
+  StmtSRef lca = LowestCommonAncestor({block_sref, loop_sref}, this->root);
+  Stmt replaced = StmtReplacer(replace_map)(GetRef<Stmt>(lca->stmt));
+  if (const auto* replaced_block = replaced.as<BlockNode>()) {
+    this->Replace(lca, replaced, {{GetRef<Block>(replaced_block), GetRef<Block>(parent_block)}});
+  } else {
+    this->Replace(lca, replaced);
+  }
+}
+
+void ScheduleNode::reverse_compute_at(const StmtSRef& block_sref, const StmtSRef& loop_sref) {
+  /*!
+   * Check:
+   * - check input_block is complete/is a dominant reduction block
+   * - check all input_block's RAW predecessors are under input_loop
+   * - check all blocks in the same sub tree are complete
+   * - check all input_block's RAW predecessors are complete/dominant reduction blocks
+   *
+   * Mutate:
+   * - generate loops that iterate the whole instance space under
+   *   input_loop, after all the predecessors
+   */
+  const auto* block = block_sref->GetStmt<BlockNode>();
+  const auto* loop = loop_sref->GetStmt<LoopNode>();
+  CHECK(block != nullptr)
+      << "TypeError: 'reverse_compute_at' expects 'block' to be a block, but got type: "
+      << block_sref->stmt->GetTypeKey();
+  CHECK(loop != nullptr)
+      << "TypeError: 'reverse_compute_at' expects 'loop' to be a loop, but got type: "
+      << loop_sref->stmt->GetTypeKey();
+  const StmtSRef& parent_block_sref = GetParentBlockSRef(block_sref);
+  const BlockNode* parent_block = parent_block_sref->GetStmt<BlockNode>();
+  const Scope& scope = scopes.at(parent_block_sref);
+  Array<DepEdge> edges_to_pred = scope.GetPredecessors(block_sref);
+  Array<DepEdge> edges_to_succ = scope.GetSuccessors(block_sref);
+  // Cond 0. `block` and `loop` are in the same scope
+  CHECK_EQ(parent_block_sref.get(), GetParentBlockSRef(loop_sref).get())
+      << "ValueError: 'reverse_compute_at' expects 'block' and 'loop' to be in the same scope";
+  // Cond 1. 'block' is complete/reduction block
+  CHECK(scope.IsComplete(block_sref) || scope.IsReduction(block_sref))
+      << "ValueError: 'reverse_compute_at' expects 'block' to be a complete or reduction block";
+  // Cond 2. Check all RAW predecessors are in the subtree rooted by loop_sref
+  CHECK(EachEdgePointsToABlock(edges_to_pred, GetChildBlocks(loop_sref), /*raw_edge_only=*/true))
+      << "ValueError: 'reverse_compute_at' does not apply to a block that depends on "
+      << "blocks outside the subtree of the given loop";
+  // Cond 3. The subtree has compact data flow
+  CHECK(scope.IsCompactDataFlow(GetSubTreeOfParent(block_sref), this))
+      << "ValueError: 'reverse_compute_at' expects the subtree of 'block' to have compact dataflow";
+  // Cond 4. Check all RAW predecessors are complete/reduction blocks
+  for (const auto& edge : edges_to_pred) {
+    CHECK(scope.IsComplete(edge->dst) || scope.IsReduction(edge->dst))
+        << "ValueError: 'reverse_compute_at' expects producers of 'block' to be complete or "
+           "reduction blocks";
+  }
+  // Mutation
+  // Step 1. Find the insertion position
+  int insert_pos;
+  {
+    // After all predecessors in the dependency graph
+    Array<Stmt> loop_body = GetChildren(GetRef<Stmt>(loop));
+    int n_stmts = loop_body.size();
+    for (insert_pos = n_stmts; insert_pos > 0; --insert_pos) {
+      const StmtNode* stmt = loop_body[insert_pos - 1].get();
+      if (AnyEdgePointsToABlock(edges_to_pred, GetChildBlocks(stmt2ref.at(stmt)))) {
+        break;
+      }
+    }
+    // Before all successors in the dependency graph
+    int before_pos;
+    for (before_pos = 0; before_pos < n_stmts; before_pos++) {
+      const StmtNode* stmt = loop_body[before_pos].get();
+      if (AnyEdgePointsToABlock(edges_to_succ, GetChildBlocks(stmt2ref.at(stmt)))) {
+        break;
+      }
+    }
+    CHECK(insert_pos <= before_pos) << "ValueError: 'reverse_compute_at' cannot find an insertion "
+                                       "point that satisfies the dependencies";
+  }
+  // Generate new LoopNode to substitute loop_sref->stmt
+  Loop new_loop = RegenerateLoops(
+      block_sref, loop_sref, insert_pos,
+      SolveCover(block, GatherRequirements(/*produced_regions=*/block->reads,
+                                           /*lca_loop_sref=*/loop_sref,
+                                           /*consumer_blocks=*/EdgesToSRefs(edges_to_pred),
+                                           /*relax_vars=*/{},
+                                           /*gather_read=*/false)));
   // Remove leaf
   std::pair<Stmt, Stmt> removed = RemoveLeaf(block_sref, this->root);
   std::unordered_map<const StmtNode*, const StmtNode*> replace_map = {
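Step 1 scans the loop body from both ends: the insertion point must come after the last
statement that contains a RAW predecessor of the block, and before the first statement that
contains a RAW successor; otherwise no position satisfies the dependencies. A plain-Python
mirror of that search (is_pred/is_succ are hypothetical stand-ins for the
AnyEdgePointsToABlock checks):

.. code-block:: python

    def find_insert_pos(stmts, is_pred, is_succ):
        n = len(stmts)
        # Largest position after every statement containing a predecessor.
        insert_pos = next((k for k in range(n, 0, -1) if is_pred(stmts[k - 1])), 0)
        # Smallest position before every statement containing a successor.
        before_pos = next((k for k in range(n) if is_succ(stmts[k])), n)
        assert insert_pos <= before_pos, "no insertion point satisfies the dependencies"
        return insert_pos

    # The consumer lands right after its producer "B" and before successor "D".
    print(find_insert_pos(["A", "B", "D"], lambda s: s == "B", lambda s: s == "D"))  # 2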
diff --git a/src/tir/schedule/schedule_validate.cc b/src/tir/schedule/schedule_validate.cc
index 3c5b881065..93b69ea840 100644
--- a/src/tir/schedule/schedule_validate.cc
+++ b/src/tir/schedule/schedule_validate.cc
@@ -134,10 +134,10 @@ bool IsAllUniqueVars(const std::vector<PrimExpr>& list) {
  * If so, it provides two functions, replace and postproc, for replacing this pattern
  * and removing them
  */
-class FuseSplitDetecter : public ExprVisitor {
+class FuseSplitDetector : public ExprVisitor {
  public:
   /*! \brief Constructor */
-  explicit FuseSplitDetecter(std::unordered_map<const VarNode*, PrimExpr>* loop_var_extents)
+  explicit FuseSplitDetector(std::unordered_map<const VarNode*, PrimExpr>* loop_var_extents)
       : loop_var_extents(loop_var_extents) {}
 
   /*! \brief Check if the PrimExpr is in fuse pattern. If so, set replace and postproc for it */
@@ -260,7 +260,7 @@ class FuseSplitDetecter : public ExprVisitor {
 class FuseSplitNormalizer : public ExprMutator {
  public:
   /*! \brief Constructor */
-  explicit FuseSplitNormalizer(const FuseSplitDetecter& detector) : detector_(detector) {}
+  explicit FuseSplitNormalizer(const FuseSplitDetector& detector) : detector_(detector) {}
 
   /*! \brief Destructor. Invoke postproc only if replacement happened at least once. */
   ~FuseSplitNormalizer() {
     if (replaced_) {
@@ -280,7 +280,7 @@ class FuseSplitNormalizer : public ExprMutator {
 
  private:
   /*! \brief The detector that has detected some pattern */
-  const FuseSplitDetecter& detector_;
+  const FuseSplitDetector& detector_;
   /*! \brief Indicating whether replacement has happened at least once */
   bool replaced_ = false;
 };
@@ -316,7 +316,7 @@ class LoopValidator : public StmtVisitor {
     std::vector<std::pair<PrimExpr, PrimExpr>> predicates = SplitPredicate(realize->predicate);
     for (;;) {
       // Detect fuse/split pattern
-      FuseSplitDetecter detector(&loop_vars);
+      FuseSplitDetector detector(&loop_vars);
       for (const auto& binding : bindings) {
         detector(binding);
         if (detector.replace) {
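For context on the renamed classes: the validator looks for block-var bindings that follow
the canonical fuse and split patterns and normalizes them away. An illustrative check of the
two patterns and why one inverts the other (plain Python, not the detector's actual
interface):

.. code-block:: python

    # fuse binds one var to   outer * inner_extent + inner;
    # split binds two vars to fused // factor and fused % factor.
    # With factor == inner_extent the round trip is exact, which is what lets
    # the validator rewrite one pattern in terms of the other.
    inner_extent = factor = 16
    for outer in range(4):
        for inner in range(inner_extent):
            fused = outer * inner_extent + inner  # fuse pattern
            assert (fused // factor, fused % factor) == (outer, inner)  # split pattern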
diff --git a/tests/python/tir/test_schedule_primitive.py b/tests/python/tir/test_schedule_primitive.py
index 97ce7a0191..23d801475e 100644
--- a/tests/python/tir/test_schedule_primitive.py
+++ b/tests/python/tir/test_schedule_primitive.py
@@ -147,6 +147,46 @@ def test_compute_at():
     assert s.validate_sref()
 
 
+@tvm.hybrid.script
+def reverse_compute_at_element_wise(a: ty.handle, c: ty.handle) -> None:
+    # function attr dict
+    C = tir.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = tir.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    B = tir.buffer_allocate([128, 128], elem_offset=0, align=128, offset_factor=1)
+
+    # body
+    for i0_outer in range(0, 8):
+        for i1_outer in range(0, 8):
+            for i0_inner in range(0, 16):
+                for i1_inner in range(0, 16):
+                    with tir.block([128, 128], "B") as [vi, vj]:
+                        tir.bind(vi, ((i0_outer*16) + i0_inner))
+                        tir.bind(vj, ((i1_outer*16) + i1_inner))
+                        B[vi, vj] = (A[vi, vj]*tir.float32(2))
+                for ax1 in range(0, 16):
+                    with tir.block([128, 128], "C") as [vi, vj]:
+                        tir.bind(vi, ((i0_outer*16) + i0_inner))
+                        tir.bind(vj, ((i1_outer*16) + ax1))
+                        C[vi, vj] = (B[vi, vj] + tir.float32(1))
+
+
+def test_reverse_compute_at():
+    func = util.element_wise_stmt()
+
+    # schedule
+    s = tir.create_schedule(func)
+    B = s.get_block("B")
+    C = s.get_block("C")
+    i, j = s.get_axes(B)
+    i1, i2 = s.split(i, 16)
+    j1, j2 = s.split(j, 16)
+    s.reorder(i1, j1, i2, j2)
+    s.reverse_compute_at(C, i2)
+
+    tvm.ir.assert_structural_equal(reverse_compute_at_element_wise, s.func)
+    assert s.validate_sref()
+
+
 @tvm.hybrid.script
 def predicate_fuse(b: ty.handle, c: ty.handle) -> None:
     C = tir.match_buffer(c, (16, 16), "float32")
@@ -493,6 +533,7 @@ def test_cache_read_write():
     test_fuse_loop_sref()
     test_reorder_normal()
     test_compute_at()
+    test_reverse_compute_at()
     test_compute_inline()
     test_compute_at_fail()
     test_reduction()