Skip to content

Commit

Permalink
[TIR] Utility function to decide loop mapping for auto tensorization (#11050)
Browse files Browse the repository at this point in the history

* [TIR] Add TensorizeInfo and GetTensorizeLoopMapping

* expose PreOrderVisit to python

* add test case

* add conv2d nchwc test

* add mma test

* add arm nhwc conv2d test

* Revert "add arm nhwc conv2d test"

This reverts commit eb147f3.

* refine

* add doc

* update

* fixed condition

* black

* pylint

* Update python/tvm/tir/schedule/analysis.py

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

* run black

* bring back logic in original code to support loop permutation

* add comment

* simplify

* minor fix to test

Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>

Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
  • Loading branch information
6 people authored Apr 20, 2022
1 parent 2025e36 commit 3823b39
Show file tree
Hide file tree
Showing 6 changed files with 418 additions and 14 deletions.
33 changes: 32 additions & 1 deletion python/tvm/tir/schedule/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
"""Analysis used in TensorIR scheduling"""
from typing import List, Optional

import tvm._ffi
from tvm.runtime import Object

from ..buffer import Buffer
from ..stmt import For
from ..expr import PrimExpr
from ..function import IndexMap
from ..function import IndexMap, PrimFunc

from . import _ffi_api
from .schedule import Schedule, BlockRV


def suggest_index_map(
Expand Down Expand Up @@ -56,3 +60,30 @@ def suggest_index_map(
loops,
predicate,
)


@tvm._ffi.register_object("tir.schedule.TensorizeInfo")
class TensorizeInfo(Object):
    """Necessary information used for tensorization.

    Python handle for the C++ ``tir.schedule.TensorizeInfo`` object produced by
    :func:`get_tensorize_loop_mapping`.  It carries the mapping between loops in
    a target block and the loops of an intrinsic description
    (``loop_map`` / ``desc_loop_indexer`` on the C++ side).
    """


def get_tensorize_loop_mapping(
    sch: Schedule, block: BlockRV, desc_func: PrimFunc
) -> Optional[TensorizeInfo]:
    """Establish a mapping between loops in a target block and an intrinsic description

    Parameters
    ----------
    sch : Schedule
        The schedule to be tensorized
    block : BlockRV
        The target block to match against
    desc_func : PrimFunc
        The prim func describing the computation to be tensorized

    Returns
    -------
    tensorize_info : Optional[TensorizeInfo]
        TensorizeInfo structure if a valid mapping is found, None otherwise
    """
    # Thin wrapper: all matching logic lives in the C++ GetTensorizeLoopMapping.
    return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func)  # type: ignore
12 changes: 12 additions & 0 deletions python/tvm/tir/stmt_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ def post_order_visit(stmt, fvisit):
return _ffi_api.PostOrderVisit(stmt, fvisit) # type: ignore


def pre_order_visit(stmt, fvisit):
    """Recursive pre-order visit on stmt AST, applying fvisit on each node.

    If fvisit returns False, it won't visit the children of the node.

    Parameters
    ----------
    stmt : tvm.runtime.ObjectRef
        The root AST node (statement) to start the visit from.
    fvisit : function of the signature Object -> bool
        The visitor function.  Return False to skip the subtree rooted at
        the current node; return True to continue descending.
    """
    return _ffi_api.PreOrderVisit(stmt, fvisit)  # type: ignore


def substitute(node, vmap):
"""Substitute the var specified by vmap.
Expand Down
4 changes: 4 additions & 0 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,10 @@ TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, Pack
tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); });
});

// FFI hook backing python `tir.stmt_functor.pre_order_visit`.
// Unlike the PostOrderVisit hook above, the packed visitor's return value is
// propagated: returning false stops the visit from descending into the
// children of the current node.
TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
  tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n); });
});

TVM_REGISTER_GLOBAL("tir.Substitute")
.set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef {
if (node->IsInstance<StmtNode>()) {
Expand Down
33 changes: 33 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,39 @@ Array<arith::IntSet> AnalyzeRegionLowerBound(const BufferRegion& region, const P
const StmtSRef& dom_high_exclusive,
arith::Analyzer* analyzer);

/*!
 * \brief Necessary information used for tensorization.
 * Produced by GetTensorizeLoopMapping below and exposed to python as
 * "tir.schedule.TensorizeInfo".
 */
class TensorizeInfoNode : public Object {
 public:
  /*! \brief Maps loops in a target block to the ones in an intrinsic description */
  Map<tir::StmtSRef, tir::For> loop_map;
  /*! \brief Maps loops in an intrinsic description to its index, outer to inner */
  Map<tir::For, Integer> desc_loop_indexer;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("loop_map", &loop_map);
    v->Visit("desc_loop_indexer", &desc_loop_indexer);
  }

  static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object);
};

/*!
 * \brief Managed reference to TensorizeInfoNode.
 * \sa TensorizeInfoNode
 */
class TensorizeInfo : public ObjectRef {
 public:
  TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode);
};

/*!
 * \brief Establish a mapping between loops in a target block and an intrinsic description
 * \param self The schedule state to be tensorized
 * \param block_sref The target block to match against
 * \param desc_func The prim func describing the computation to be tensorized
 * \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
 * \note NullOpt is returned when the target block has fewer block vars than the
 *       description, when a block loop does not provably start at 0, or when a
 *       block loop extent is not a constant multiple of the matching description
 *       loop extent.
 */
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
                                                const tir::StmtSRef& block_sref,
                                                const tir::PrimFunc& desc_func);

} // namespace tir
} // namespace tvm

Expand Down
167 changes: 165 additions & 2 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/container/optional.h>
#include <tvm/tir/expr.h>

#include "../utils.h"

namespace tvm {
Expand Down Expand Up @@ -492,8 +495,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref,
}
}

std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
std::vector<IterVarType> GetBlockVarTypes(const BlockNode* block) {
std::vector<IterVarType> results;
results.reserve(block->iter_vars.size());
for (const IterVar& iter_var : block->iter_vars) {
Expand All @@ -502,6 +504,11 @@ std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
return results;
}

/*! \brief Get the iterator types of the block vars of the block referenced by `block_sref`. */
std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) {
  // Resolve the sref to its underlying BlockNode, then delegate to the
  // BlockNode* overload above.
  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
  return GetBlockVarTypes(block);
}

bool IsWriteCache(const StmtSRef& block_sref) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
if (block->writes.size() != 1) {
Expand Down Expand Up @@ -2028,5 +2035,161 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
}
}

TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);

// Try to match the loop nest surrounding `block_sref` against the loop nest of the
// intrinsic description `desc_func`; on success, return which target loop corresponds
// to which description loop (see TensorizeInfoNode in analysis.h).
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
                                                const tir::StmtSRef& block_sref,
                                                const tir::PrimFunc& desc_func) {
  arith::Analyzer analyzer;
  const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
  // Step 1. Analyze desc_func, extract its block, loops and loop vars
  const tir::BlockRealizeNode* desc_block = nullptr;
  std::vector<const tir::ForNode*> desc_loops;
  std::unordered_set<const tir::VarNode*> desc_loop_vars;
  // The description is expected to be a single root BlockRealize wrapping the loop nest.
  const auto* desc_scope_realize = desc_func->body.as<tir::BlockRealizeNode>();
  ICHECK(desc_scope_realize);
  {
    // NOTE(review): f_visit returns bool but is handed to PostOrderVisit below;
    // confirm the early-exit `return false` is actually honored by that visitor.
    auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars,
                    &analyzer](const ObjectRef& obj) -> bool {
      // Extract the block
      if (const auto* block = obj.as<tir::BlockRealizeNode>()) {
        desc_block = block;
        return false;
      }
      // Extract loops
      if (const auto* loop = obj.as<tir::ForNode>()) {
        desc_loops.push_back(loop);
        desc_loop_vars.insert(loop->loop_var.get());
        // Only loops starting at 0 are considered for the mapping.
        if (!analyzer.CanProve(loop->min == 0)) {
          return false;
        }
      }
      return true;
    };
    tir::PostOrderVisit(desc_scope_realize->block->body, f_visit);
    // Loops were collected innermost-first; flip to outer-to-inner order.
    std::reverse(desc_loops.begin(), desc_loops.end());
    ICHECK(desc_block);
  }
  // Step 2. Collect loops from block_sref
  const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
  // NOTE(review): scope_block itself is unused below; the TVM_SREF_TO_BLOCK cast
  // presumably serves as a sanity check that the scope root is a block — confirm.
  const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref);
  std::vector<const tir::ForNode*> block_loops;
  std::unordered_set<const tir::VarNode*> block_loop_vars;
  {
    // Walk upward from the block, stopping at the first non-loop parent or at a
    // loop whose body is a SeqStmt (i.e. the block is no longer the sole child).
    for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) {
      const auto* loop = loop_sref->StmtAs<tir::ForNode>();
      if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) {
        break;
      }
      block_loops.push_back(loop);
      block_loop_vars.insert(loop->loop_var.get());
      // Target loops must provably start at 0, otherwise no mapping is possible.
      if (!analyzer.CanProve(loop->min == 0)) {
        return NullOpt;
      }
    }
    // Collected innermost-first; flip to outer-to-inner order.
    std::reverse(block_loops.begin(), block_loops.end());
  }
  // Step 3. Map from block loops to desc block loops
  ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>();
  const int n_block_vars = block->iter_values.size();
  const int n_desc_vars = desc_block->iter_values.size();
  // The trailing `n_desc_vars` block vars of the target are matched against the
  // description; leading `offset` vars are left untouched (e.g. batch dims).
  const int offset = n_block_vars - n_desc_vars;

  if (offset < 0) {
    return NullOpt;
  }

  const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref);
  const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get());

  ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars));
  ICHECK(block_loops.size() == iter_types_block.size());

  // We assume that the orders of iter_vars in the target and the desc block are consistent.
  // Based on that assumption, the following logic supports arbitrary permutations of a loop order,
  // such as

  // for k:
  //   for i:
  //     for j:
  //       C[i, j] += A[i, k] * B[k, j]

  // or

  // for i:
  //   for j:
  //     for k:
  //       C[i, j] += A[i, k] * B[k, j]

  // Scan description vars from innermost to outermost, consuming target block
  // vars (also right-to-left) with matching iterator types as we go.
  int next_block_ind = block_loops.size() - 1;
  for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) {
    // Step 3.1. Find the corresponding loop of the i_desc-th block var of desc
    const PrimExpr& desc_bind = desc_block->iter_values[i_desc];
    const tir::ForNode* desc_loop = nullptr;
    IterVarType iter_type_desc = iter_types_desc[i_desc];
    for (int i = 0, n = desc_loops.size(); i < n; ++i) {
      // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars
      PrimExpr residual = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var);
      if (!UsesVar(residual,
                   [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) {
        desc_loop = desc_loops[i];
        iter_type_desc = iter_types_desc[i];
        break;
      }
    }
    // The description loop must exist and have a constant extent.
    if (desc_loop == nullptr || desc_loop->extent.as<IntImmNode>() == nullptr) {
      return NullOpt;
    }

    const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>();

    // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
    PrimExpr block_bind;
    for (int i = next_block_ind; i >= 0; --i) {
      if (iter_types_block[i] == iter_type_desc) {
        next_block_ind = i - 1;
        block_bind = block->iter_values[i];
        break;
      }
    }

    if (!block_bind.defined()) return NullOpt;

    // Step 3.3. Find the corresponding loop of the target block
    for (int i = 0, n = block_loops.size(); i < n; ++i) {
      // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars
      const tir::ForNode* block_loop = block_loops[i];
      const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop];
      // Skip i-th loop if it has already been mapped
      if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue;

      PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
      if (UsesVar(residual,
                  [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); }))
        continue;

      const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();

      // Check divisibility
      if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
        return NullOpt;
      }

      ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
      break;
    }
  }

  // Record the outer-to-inner index of every description loop for later lookup.
  for (int i = 0, n = desc_loops.size(); i < n; ++i) {
    ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
  }
  return TensorizeInfo(ret);
}

// FFI hook backing python `tir.schedule.analysis.get_tensorize_loop_mapping`:
// resolves the BlockRV to its sref and runs the matcher on the schedule state.
TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
    .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
      return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
    });

} // namespace tir
} // namespace tvm
Loading

0 comments on commit 3823b39

Please sign in to comment.