Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Meta Schedule][BugFix] Fix PostProc RewriteReductionBlock #487

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 23 additions & 33 deletions src/meta_schedule/space/postproc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,13 @@ class PostprocRewriteReductionBlock {
Finder finder;
finder.CollectBoundLoops(stmt);
finder.VisitStmt(stmt);
ICHECK(finder.result_ == nullptr || (finder.result_->init.defined() &&
ICHECK(finder.result_ == nullptr || (finder.result_->block->init.defined() &&
finder.AllReductionIterVarAreUnbound(finder.result_)));
return finder.result_;
return finder.result_->block.get();
}

private:
Finder() : bound_loop_vars_(), stack_(), result_(nullptr) {}
Finder() : bound_loop_vars_(), result_(nullptr) {}

/*!
* \brief Collect all the loops inside `stmt` that are bound to threadIdx or blockIdx.
Expand All @@ -683,30 +683,31 @@ class PostprocRewriteReductionBlock {
/*!
* \brief Check whether the two following conditions are both satisfied:
* 1. the block has at least one reduction block var, and
* 2. none of its reduction block var bindings is bound to threadIdx.
* \param block_sref The block to be checked
* \return A boolean indicating if it has at least one reduction block var
* 2. none of its reduction block var bindings is bound to `threadIdx.x/y/z` or
* `blockIdx.x/y/z`
* \param block_realize The block-realize of the block to be checked
* \return A boolean indicating if the above conditions are satisfied
*/
bool AllReductionIterVarAreUnbound(const tir::BlockNode* block) {
bool AllReductionIterVarAreUnbound(const tir::BlockRealizeNode* block_realize) {
bool has_reduction_var = false;
CHECK(!stack_.empty() && stack_.back()->block.get() == block)
<< "ValueError: the block has outer BlockRealize or the outer BlockRealize doesn't match "
"the block.";
const tir::BlockRealize& block_realize = GetRef<tir::BlockRealize>(stack_.back());
const tir::Block& block = block_realize->block;
ICHECK_EQ(block_realize->iter_values.size(), block->iter_vars.size());
for (int i = 0; i < static_cast<int>(block->iter_vars.size()); ++i) {
for (int i = 0, n = static_cast<int>(block->iter_vars.size()); i < n; ++i) {
const tir::IterVar& var = block->iter_vars[i];
const PrimExpr& binding = block_realize->iter_values[i];
if (var->iter_type == tir::kCommReduce &&
tir::UsesVar(binding,
if (var->iter_type != tir::kCommReduce) {
continue;
}
if (tir::UsesVar(binding,
[&](const tir::VarNode* node) { return bound_loop_vars_.count(node); })) {
return false;
}
has_reduction_var = true;
}
return has_reduction_var;
}

void VisitStmt_(const tir::BlockNode* block) override {
void VisitStmt_(const tir::BlockRealizeNode* block_realize) override {
if (result_ != nullptr) {
return;
}
Expand All @@ -715,32 +716,21 @@ class PostprocRewriteReductionBlock {
* 2. If some of its reduction block var bindings are bound to threadIdx, this indicates
* that cross-thread-reduction is needed, and hence we should not decompose the init block.
*/
if (block->init.defined() && AllReductionIterVarAreUnbound(block)) {
result_ = block;
if (block_realize->block->init.defined() && AllReductionIterVarAreUnbound(block_realize)) {
result_ = block_realize;
} else {
tir::StmtVisitor::VisitStmt_(block);
}
}

void VisitStmt_(const tir::BlockRealizeNode* block_realize) override {
if (result_ != nullptr) {
return;
tir::StmtVisitor::VisitStmt_(block_realize);
}
stack_.push_back(block_realize);
tir::StmtVisitor::VisitStmt_(block_realize);
ICHECK(!stack_.empty());
stack_.pop_back();
}

/*! \brief A set recording all the bound loop vars. */
std::unordered_set<const tir::VarNode*> bound_loop_vars_;
/*! \brief A stack recording all the BlockRealizes along the visiting path. */
std::vector<const tir::BlockRealizeNode*> stack_;
/*!
* \brief The result block which has at least one reduction block var and none of the block var
* bindings is bound to threadIdx (i.e., cross-thread-reduction is not needed).
* \brief The block-realize of the result block which has at least one reduction block var and
* none of the block var bindings is bound to threadIdx (i.e., cross-thread-reduction is not
* needed).
*/
const tir::BlockNode* result_;
const tir::BlockRealizeNode* result_;
};

bool Proc(const Schedule& sch) const {
Expand Down