diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 613d28f7bf..2e06799336 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -883,7 +883,7 @@ CommReducer::CommReducer(Array lhs, Array rhs, Array result, data_ = std::move(node); } -`bool CommReducer::MatchReducer(const CommReducer& reducer, const PrimExpr& identity, +bool CommReducer::MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner, Optional& lhs, Optional& rhs) { ExprDeepEqual equal; diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 18be68e120..f06a078565 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -24,6 +24,8 @@ #include #include +#include "../../runtime/thread_storage_scope.h" + namespace tvm { namespace tir { diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index d75818f088..9af80bec4c 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -216,7 +216,7 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, } return true; }); - return affected; + return !affected; } /******** Binding ********/ @@ -523,40 +523,6 @@ BufferRegion SubstituteBufferRegion(const BufferRegion& buffer_region, return BufferRegion(new_buffer_region); } -/******** Block Information Update ********/ - -void UpdateScope(ScheduleState self, const StmtSRef& block_sref) { - BlockScope scope(tir::GetChildBlocks(self, block_sref)); - // The caller is responsible for correcting the flags - bool affine_binding = false; - bool region_cover = false; - self->block_info[block_sref] = - BlockInfo(std::move(scope), affine_binding, region_cover); -} - -void UpdateAffineFlag(ScheduleState self, const StmtSRef& block_sref) { - if (block_sref->parent == nullptr) { - ICHECK(self->block_info.count(block_sref)); - self->block_info[block_sref].affine_binding = true; - return; - } - BlockRealize realize = GetBlockRealize(block_sref); - const auto* block = TVM_SREF_TO_BLOCK(block, block_sref); - Map loop_var_ranges; - for (StmtSRefNode* loop_sref = block_sref->parent; loop_sref != nullptr; - loop_sref = loop_sref->parent) { - if (const auto* loop = loop_sref->StmtAs()) { - loop_var_ranges.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); - } else { - break; - } - } - ICHECK(self->block_info.count(block_sref)); - arith::Analyzer analyzer; - self->block_info[block_sref].affine_binding = - IsAffineBinding(realize, loop_var_ranges, &analyzer); -} - /******** Pattern Matcher ********/ void PatternMatcher::VisitExpr_(const VarNode* op) { diff --git a/src/tir/schedule/primitives/reduction.cc b/src/tir/schedule/primitives/reduction.cc index 20251c001d..6816a6dc51 100644 --- a/src/tir/schedule/primitives/reduction.cc +++ b/src/tir/schedule/primitives/reduction.cc @@ -40,9 +40,8 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis) BlockRealize block_realize = GetBlockRealize(block_sref); Block block = block_realize->block; Optional scope_root = GetScopeRoot(block_sref); -// Todo: comment out the two lines below out after Junru's PR getting merged -// CHECK(IsReductionBlock(self, block_sref, scope_root.value())) -// << "ValueError: We can only do rfactor for loops of a reduction block"; + CHECK(IsReductionBlock(self, block_sref, scope_root.value())) + << "ValueError: We can only do rfactor for loops of a reduction block"; // Collect the information of the reduction. // Get the `init` identity and the `update` combiner of the reduction. @@ -390,10 +389,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis) self->Replace(scope_root.value(), new_scope_block, {{scope_block, new_scope_block}}); // Update scope information. StmtSRef rf_block_sref = self->stmt2ref.at(rf_block.get()); - UpdateScope(self, scope_root.value()); - UpdateAffineFlag(self, scope_root.value()); - UpdateAffineFlag(self, rf_block_sref); - // Todo: in which cases should we call UpdateScope & UpdateAffineFlag? + self->block_info[rf_block_sref].affine_binding = true; + self->block_info[rf_block_sref].region_cover = true; + self->block_info[rf_block_sref].scope->stage_pipeline = true; return rf_block_sref; }