Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Nov 11, 2021
1 parent 8ee5d4c commit 5546abb
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 32 deletions.
148 changes: 148 additions & 0 deletions src/meta_schedule/postproc/rewrite_reduction_block.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*! \brief The visitor that finds all the reduction block to be decomposed */
struct ReductionBlockFinder : private StmtVisitor {
public:
/*! \brief Find all the reduction blocks that should be decomposed */
static std::vector<std::pair<StmtSRef, String>> Find(const ScheduleState& self) {
std::vector<std::pair<StmtSRef, String>> results;
for (const auto& kv : self->mod->functions) {
GlobalVar g_var = kv.first;
BaseFunc base_func = kv.second;
if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
ReductionBlockFinder finder;
finder(prim_func->body);
for (const BlockNode* block : finder.results_) {
results.emplace_back(self->stmt2ref.at(block), g_var->name_hint);
}
}
}
return results;
}

private:
void VisitStmt_(const ForNode* loop) final {
runtime::ThreadScope thread_scope = GetThreadScope(loop);
if (IsThreadIdx(thread_scope) || IsBlockIdx(thread_scope)) {
thread_bound_loop_vars_.insert(loop->loop_var.get());
}
StmtVisitor::VisitStmt_(loop);
}

void VisitStmt_(const BlockRealizeNode* realize) final {
if (realize->block->init.defined() && AllReductionIterVarAreUnbound(realize)) {
results_.push_back(realize->block.get());
}
StmtVisitor::VisitStmt_(realize);
}

bool AllReductionIterVarAreUnbound(const BlockRealizeNode* realize) const {
if (thread_bound_loop_vars_.empty()) {
return true;
}
auto f_find = [this](const VarNode* var) -> bool { return thread_bound_loop_vars_.count(var); };
const BlockNode* block = realize->block.get();
int n = block->iter_vars.size();
for (int i = 0; i < n; ++i) {
IterVar iter_var = block->iter_vars[i];
PrimExpr binding = realize->iter_values[i];
if (iter_var->iter_type == tir::kCommReduce) {
if (UsesVar(binding, f_find)) {
return false;
}
}
}
return true;
}

/*! \brief The results of the collection */
std::vector<const BlockNode*> results_;
/*! \brief Loop variables that are bound to threads */
std::unordered_set<const VarNode*> thread_bound_loop_vars_;
};

/*!
* \brief Find the innermost loop that could be decomposed to
* \param block_sref The block to be decomposed
* \return The index of the innermost loop that could be decomposed
*/
int FindDecomposePoint(const StmtSRef& block_sref) {
Array<StmtSRef> loop_srefs = GetLoops(block_sref);
int n = loop_srefs.size();
for (int i = 0; i < n; ++i) {
if (GetLoopIterType(loop_srefs[i]) != IterVarType::kDataPar) {
return i;
}
}
return -1;
}

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

/*! \brief Rewrite reduction block by moving the init block out */
class RewriteReductionBlockNode : public PostprocNode {
public:
// Inherited from PostprocNode
void InitializeWithTuneContext(const TuneContext& context) final {}
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final;

void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock";
TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode);
};

bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
for (;;) {
std::vector<std::pair<tir::StmtSRef, String>> results =
tir::ReductionBlockFinder::Find(sch->state());
int rewritten = 0;
for (const auto& kv : results) {
const tir::StmtSRef& block_sref = kv.first;
const String& global_var_name = kv.second;
int decompose_point = tir::FindDecomposePoint(block_sref);
if (decompose_point == -1) {
continue;
}
tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]);
++rewritten;
}
if (rewritten == 0) {
break;
}
}
return true;
}

TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode);

} // namespace meta_schedule
} // namespace tvm
43 changes: 11 additions & 32 deletions src/meta_schedule/postproc/rewrite_unbound_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,41 +23,14 @@ namespace tir {

/*! \brief The rewrite type for an unbound block */
enum class BindType : int32_t {
/*! \brief No additional thread binding is needed */
kNoBind = 0,
/*! \brief Need to bind to blockIdx */
kBindBlock = 1,
/*! \brief Need to bind to both blockIdx and threadIdx */
kBindBlockThread = 2,
};

/*!
* \brief Get the thread scope bound to the specific loop
* \param loop The loop to be inspected
* \return The thread scope bound to the loop
*/
runtime::ThreadScope GetThreadScope(const ForNode* loop) {
if (loop->kind == ForKind::kThreadBinding) {
return runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag);
}
return runtime::ThreadScope{-1, -1};
}

/*!
* \brief Check if the thread scope is blockIdx
* \param thread_scope The thread scope to be checked
* \return True if the thread scope is blockIdx
*/
bool IsBlockIdx(const runtime::ThreadScope& thread_scope) {
return thread_scope.rank == 0; // The rank of blockIdx is 0
}

/*!
* \brief Check if the thread scope is threadIdx
* \param thread_scope The thread scope to be checked
* \return True if the thread scope is threadIdx
*/
bool IsThreadIdx(const runtime::ThreadScope& thread_scope) {
return thread_scope.rank == 1 && thread_scope.dim_index >= 0;
}

/*!
* \brief Check the combination of bindings to be added to the block
* \param block_sref The block to be checked
Expand Down Expand Up @@ -149,12 +122,17 @@ class UnboundBlockFinder : private StmtVisitor {
}

explicit UnboundBlockFinder(const ScheduleState& self)
: self_{self}, blocks_{}, n_thread_idx_{0}, n_block_idx_{0} {}
: self_{self}, blocks_{}, n_block_idx_{0}, n_thread_idx_{0} {}

/*! \brief The schedule state */
const ScheduleState& self_;
/*! \brief The list of unbound blocks */
std::vector<std::pair<StmtSRef, String>> blocks_;
int n_thread_idx_;
/*! \brief The number of blockIdx above the current stmt */
int n_block_idx_;
/*! \brief The number of threadIdx above the current stmt */
int n_thread_idx_;
/*! \brief The name of the global var */
String global_var_name_;
};

Expand All @@ -179,6 +157,7 @@ class RewriteUnboundBlockNode : public PostprocNode {
bool Apply(const tir::Schedule& sch) final;

public:
/*! \brief The cached warp size from Target */
int warp_size_ = -1;

void VisitAttrs(tvm::AttrVisitor* v) {
Expand Down
7 changes: 7 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,13 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
*/
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref);

/*!
* \brief Get the IterVarType of the specific loop, according to the blocks it's bound to
* \param loop_sref The loop to be checked
* \return The IterVarType of the specific loop
*/
IterVarType GetLoopIterType(const StmtSRef& loop_sref);

/******** Producer-consumer relation ********/

/*!
Expand Down
47 changes: 47 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,53 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
}
}

IterVarType GetLoopIterType(const StmtSRef& loop_sref) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
const Var& loop_var = loop->loop_var;
int n_spatial = 0;
int n_reduce = 0;
int n_other = 0;
auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool {
if (const auto* realize = obj.as<BlockRealizeNode>()) {
const BlockNode* block = realize->block.get();
// Number of block vars and their bindings
ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size());
int n = realize->iter_values.size();
for (int i = 0; i < n; ++i) {
const IterVar& iter_var = block->iter_vars[i];
const PrimExpr& binding = realize->iter_values[i];
// Categorize the current block var
int* ref = nullptr;
if (iter_var->iter_type == IterVarType::kDataPar) {
ref = &n_spatial;
} else if (iter_var->iter_type == IterVarType::kCommReduce) {
ref = &n_reduce;
} else {
ref = &n_other;
}
// Visit the binding to see if `loop_var` appears
PostOrderVisit(binding, [&ref, &loop_var](const ObjectRef& obj) -> void {
if (obj.same_as(loop_var)) {
(*ref) += 1;
}
});
}
return false;
}
return true;
};
PreOrderVisit(loop->body, f_visit);
if (n_other) {
return IterVarType::kOpaque;
} else if (n_spatial && n_reduce) {
return IterVarType::kOpaque;
} else if (n_reduce) {
return IterVarType::kCommReduce;
} else {
return IterVarType::kDataPar;
}
}

/******** Producer-consumer relation ********/

Array<StmtSRef> GetProducers(const StmtSRef& block_sref, const BlockScope& scope) {
Expand Down
30 changes: 30 additions & 0 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,36 @@ inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_va
Var(std::move(name), loop->loop_var.dtype()), iter_var_type);
}

/*!
* \brief Get the thread scope bound to the specific loop
* \param loop The loop to be inspected
* \return The thread scope bound to the loop
*/
inline runtime::ThreadScope GetThreadScope(const ForNode* loop) {
if (loop->kind == ForKind::kThreadBinding) {
return runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag);
}
return runtime::ThreadScope{-1, -1};
}

/*!
* \brief Check if the thread scope is blockIdx
* \param thread_scope The thread scope to be checked
* \return True if the thread scope is blockIdx
*/
inline bool IsBlockIdx(const runtime::ThreadScope& thread_scope) {
return thread_scope.rank == 0; // The rank of blockIdx is 0
}

/*!
* \brief Check if the thread scope is threadIdx
* \param thread_scope The thread scope to be checked
* \return True if the thread scope is threadIdx
*/
inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) {
return thread_scope.rank == 1 && thread_scope.dim_index >= 0;
}

/******** Integer set ********/

/*!
Expand Down

0 comments on commit 5546abb

Please sign in to comment.