Skip to content

Commit

Permalink
clean up headers
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent d8b2aa3 commit 2fc118b
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 165 deletions.
2 changes: 0 additions & 2 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
*/
#include "multi_level_tiling.h"

#include <unordered_map>

#include "../utils.h"

namespace tvm {
Expand Down
9 changes: 7 additions & 2 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <unordered_map>
#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_
#define TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_

#include "../utils.h"
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/tir/schedule/schedule.h>
#include "../../support/array.h"

namespace tvm {
namespace meta_schedule {
Expand Down Expand Up @@ -206,3 +209,5 @@ ScheduleRule MultiLevelTilingInitCommon(String structure, Optional<Array<String>

} // namespace meta_schedule
} // namespace tvm

#endif // TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_
164 changes: 3 additions & 161 deletions src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,180 +16,22 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <unordered_map>

#include "../utils.h"
#include "multi_level_tiling.h"
#include "../../tir/schedule/analysis.h"

namespace tvm {
namespace meta_schedule {

using tir::LoopRV;

/*! \brief Necessary information used for tensorization */
class TensorizeInfoNode : public Object {
public:
/*! \brief Maps block loops to desc loops */
Map<tir::StmtSRef, tir::For> loop_map;
/*! \brief Maps loops in desc to its index, outer to inner */
Map<tir::For, Integer> desc_loop_indexer;

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_map", &loop_map);
v->Visit("desc_loop_indexer", &desc_loop_indexer);
}

static constexpr const char* _type_key = "tir.analysis.TensorizeInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
};

class TensorizeInfo : public ObjectRef {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
};

TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);

Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func) {
// Try to do tiling automatically if possible
// Now the heuristic is that if block's block var binding is constant + loop var,
// in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder
// i, j, k according to the loops outside desc_block
// Collect the loops outside block
arith::Analyzer analyzer;
const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
// Step 1. Analyze desc_func, extract its block, loops and loop vars
const tir::BlockRealizeNode* desc_block = nullptr;
std::vector<const tir::ForNode*> desc_loops;
std::unordered_set<const tir::VarNode*> desc_loop_vars;
const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
ICHECK(desc_scope_realize);
{
auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
&analyzer](const ObjectRef& obj) -> bool {
// Extract the block
if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
desc_block = block;
return false;
}
// Extract the loops
if (const auto* loop = obj.as<tir::ForNode>()) {
desc_loops.push_back(loop);
desc_loop_vars.insert(loop->loop_var.get());
if (!analyzer.CanProve(loop->min == 0)) {
return false;
}
}
return true;
};
tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
std::reverse(desc_loops.begin(), desc_loops.end());
ICHECK(desc_block);
}
// Step 2. Check if `desc_block` matches `block`
// Ignore the scope of buffers when comparing, since we can do cache_read/write
const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
std::vector<const tir::ForNode*> block_loops;
std::unordered_set<const tir::VarNode*> block_loop_vars;
{
for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
const auto* loop = loop_sref->StmtAs<tir::ForNode>();
if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
break;
}
block_loops.push_back(loop);
block_loop_vars.insert(loop->loop_var.get());
if (!analyzer.CanProve(loop->min == 0)) {
return NullOpt;
}
}
std::reverse(block_loops.begin(), block_loops.end());
}
// Step 4. Map from block loops to desc block loops
ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
int n_block_vars = block->iter_values.size();
int n_desc_vars = desc_block->iter_values.size();
int offset = n_block_vars - n_desc_vars;
if (offset < 0) {
return NullOpt;
}
// We align the block and desc block's bindings from the right side
// block (v0=..., v1=..., v2=...)
// ^ i_block
// desc_block( v1=..., v2=...)
// ^ i_desc
for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) {
// For each block var binding, we find
const PrimExpr& block_bind = block->iter_values[i_block];
const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
// Step 4.1. Find the corresponding loop of the i-th block var of block
const tir::ForNode* block_loop = nullptr;
for (int i = 0, n = block_loops.size(); i < n; ++i) {
// Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
if (!tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) {
return block_loop_vars.count(var);
})) {
block_loop = block_loops[i];
break;
}
}
if (block_loop == nullptr) {
return NullOpt;
}
// Step 4.2. Find the corresponding loop of the i-th block var of desc
const tir::ForNode* desc_loop = nullptr;
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
// Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
if (!tir::UsesVar(r, [&desc_loop_vars](const tir::VarNode* var) {
return desc_loop_vars.count(var);
})) {
desc_loop = desc_loops[i];
break;
}
}
if (block_loop == nullptr) {
return NullOpt;
}
// Step 4.3. Check divisibility of loop extents
PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
if (const auto* int_block_extent = block_extent.as<IntImmNode>()) {
if (const auto* int_desc_extent = desc_extent.as<IntImmNode>()) {
if (int_block_extent->value % int_desc_extent->value != 0) {
return NullOpt;
}
} else {
return NullOpt;
}
} else {
return NullOpt;
}
// Step 4.4. Maps the result of Step 4.1 to Step 4.2
const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
auto it = ret->loop_map.find(block_loop_sref);
if (it == ret->loop_map.end()) {
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
} else if ((*it).second.get() != desc_loop) {
return NullOpt;
}
}
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
}
return TensorizeInfo(ret);
}

Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name) {
Optional<TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
if (!opt_tensorize_info) return NullOpt;
const TensorizeInfoNode* info = opt_tensorize_info.value().get();
const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
// Construct a mapping from tir loops back to LoopRVs
Map<tir::StmtSRef, LoopRV> loop2rv;
{
Expand Down
26 changes: 26 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,32 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const P
const StmtSRef& dom_high_exclusive,
arith::Analyzer* analyzer);

/*! \brief Necessary information used for tensorization */
class TensorizeInfoNode : public Object {
public:
/*! \brief Maps block loops to desc loops */
Map<tir::StmtSRef, tir::For> loop_map;
/*! \brief Maps loops in desc to its index, outer to inner */
Map<tir::For, Integer> desc_loop_indexer;

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_map", &loop_map);
v->Visit("desc_loop_indexer", &desc_loop_indexer);
}

static constexpr const char* _type_key = "tir.analysis.TensorizeInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
};

class TensorizeInfo : public ObjectRef {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
};

Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func);

} // namespace tir
} // namespace tvm

Expand Down
136 changes: 136 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1992,5 +1992,141 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
}
}

TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);

Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func) {
// Try to do tiling automatically if possible
// Now the heuristic is that if block's block var binding is constant + loop var,
// in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder
// i, j, k according to the loops outside desc_block
// Collect the loops outside block
arith::Analyzer analyzer;
const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
// Step 1. Analyze desc_func, extract its block, loops and loop vars
const tir::BlockRealizeNode* desc_block = nullptr;
std::vector<const tir::ForNode*> desc_loops;
std::unordered_set<const tir::VarNode*> desc_loop_vars;
const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
ICHECK(desc_scope_realize);
{
auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
&analyzer](const ObjectRef& obj) -> bool {
// Extract the block
if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
desc_block = block;
return false;
}
// Extract the loops
if (const auto* loop = obj.as<tir::ForNode>()) {
desc_loops.push_back(loop);
desc_loop_vars.insert(loop->loop_var.get());
if (!analyzer.CanProve(loop->min == 0)) {
return false;
}
}
return true;
};
tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
std::reverse(desc_loops.begin(), desc_loops.end());
ICHECK(desc_block);
}
// Step 2. Check if `desc_block` matches `block`
// Ignore the scope of buffers when comparing, since we can do cache_read/write
const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false);
const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
std::vector<const tir::ForNode*> block_loops;
std::unordered_set<const tir::VarNode*> block_loop_vars;
{
for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
const auto* loop = loop_sref->StmtAs<tir::ForNode>();
if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
break;
}
block_loops.push_back(loop);
block_loop_vars.insert(loop->loop_var.get());
if (!analyzer.CanProve(loop->min == 0)) {
return NullOpt;
}
}
std::reverse(block_loops.begin(), block_loops.end());
}
// Step 4. Map from block loops to desc block loops
ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
int n_block_vars = block->iter_values.size();
int n_desc_vars = desc_block->iter_values.size();
int offset = n_block_vars - n_desc_vars;
if (offset < 0) {
return NullOpt;
}
// We align the block and desc block's bindings from the right side
// block (v0=..., v1=..., v2=...)
// ^ i_block
// desc_block( v1=..., v2=...)
// ^ i_desc
for (int i_desc = 0, i_block = offset; i_desc < n_desc_vars; ++i_desc, ++i_block) {
// For each block var binding, we find
const PrimExpr& block_bind = block->iter_values[i_block];
const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
// Step 4.1. Find the corresponding loop of the i-th block var of block
const tir::ForNode* block_loop = nullptr;
for (int i = 0, n = block_loops.size(); i < n; ++i) {
// Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
if (!tir::UsesVar(r, [&block_loop_vars](const tir::VarNode* var) {
return block_loop_vars.count(var);
})) {
block_loop = block_loops[i];
break;
}
}
if (block_loop == nullptr) {
return NullOpt;
}
// Step 4.2. Find the corresponding loop of the i-th block var of desc
const tir::ForNode* desc_loop = nullptr;
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
// Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
if (!tir::UsesVar(r, [&desc_loop_vars](const tir::VarNode* var) {
return desc_loop_vars.count(var);
})) {
desc_loop = desc_loops[i];
break;
}
}
if (block_loop == nullptr) {
return NullOpt;
}
// Step 4.3. Check divisibility of loop extents
PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
if (const auto* int_block_extent = block_extent.as<IntImmNode>()) {
if (const auto* int_desc_extent = desc_extent.as<IntImmNode>()) {
if (int_block_extent->value % int_desc_extent->value != 0) {
return NullOpt;
}
} else {
return NullOpt;
}
} else {
return NullOpt;
}
// Step 4.4. Maps the result of Step 4.1 to Step 4.2
const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
auto it = ret->loop_map.find(block_loop_sref);
if (it == ret->loop_map.end()) {
ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
} else if ((*it).second.get() != desc_loop) {
return NullOpt;
}
}
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
}
return TensorizeInfo(ret);
}

} // namespace tir
} // namespace tvm

0 comments on commit 2fc118b

Please sign in to comment.