Skip to content

Commit

Permalink
[MetaSchedule] Support padding for irregular shapes for CUDA tensor core
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Sep 15, 2022
1 parent 1f8b5de commit 1b92747
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 17 deletions.
5 changes: 3 additions & 2 deletions python/tvm/tir/schedule/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class TensorizeInfo(Object):


def get_tensorize_loop_mapping(
sch: Schedule, block: BlockRV, desc_func: PrimFunc
sch: Schedule, block: BlockRV, desc_func: PrimFunc, allow_padding: bool
) -> Optional[TensorizeInfo]:
"""Establish a mapping between loops in a target block and an intrinsic description
Expand All @@ -80,7 +80,8 @@ def get_tensorize_loop_mapping(
The target block to match against
desc_func : PrimFunc
The prim func describing the computation to be tensorized
allow_padding : bool
Whether to allow padding the block iters to match the intrinsic description
Returns
-------
tensorize_info : Optional[TensorizeInfo]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
state->sch->TransformBlockLayout(state->tensor_core_reindex_B, index_map);
state->sch->TransformBlockLayout(state->block_rv, index_map);

return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name);
return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name,
/*allow_padding=*/true);
}

inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorization(
Expand Down
8 changes: 7 additions & 1 deletion src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -731,10 +731,15 @@ class TensorizeInfoNode : public Object {
Map<tir::StmtSRef, tir::For> loop_map;
/*! \brief Maps loops in an intrinsic description to its index, outer to inner */
Map<tir::For, Integer> desc_loop_indexer;
/*! \brief Optional amount of padding to add to the extent of each block iter (0 for iters that
* need none), set when padding is needed to match the intrinsic description
*/
Optional<Array<Integer>> block_iter_paddings;

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_map", &loop_map);
v->Visit("desc_loop_indexer", &desc_loop_indexer);
v->Visit("block_iter_paddings", &block_iter_paddings);
}

static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
Expand All @@ -751,11 +756,12 @@ class TensorizeInfo : public ObjectRef {
* \param self The schedule state to be tensorized
* \param block_sref The target block to match against
* \param desc_func The prim func describing the computation to be tensorized
* \param allow_padding Whether to allow padding the block iters to match the intrinsic description
* \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
*/
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func);
const tir::PrimFunc& desc_func, bool allow_padding);

/*!\brief Necessary information used to perform transformations for tensorization */
class AutoTensorizeMappingInfoNode : public Object {
Expand Down
53 changes: 44 additions & 9 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1699,7 +1699,8 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer,

Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func) {
const tir::PrimFunc& desc_func,
bool allow_padding) {
arith::Analyzer analyzer;
const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
// Step 1. Analyze desc_func, extract its block, loops and loop vars
Expand Down Expand Up @@ -1732,6 +1733,8 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const int n_desc_vars = desc_block->iter_values.size();
const int offset = n_block_vars - n_desc_vars;

std::unordered_map<int, int> block_index_to_padding; // padding of each block iter if necessary

if (offset < 0) {
return NullOpt;
}
Expand Down Expand Up @@ -1782,10 +1785,11 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,

// Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
PrimExpr block_bind;
for (int i = next_block_ind; i >= 0; --i) {
if (iter_types_block[i] == iter_type_desc) {
next_block_ind = i - 1;
block_bind = block->iter_values[i];
int current_block_ind = next_block_ind;
for (; current_block_ind >= 0; --current_block_ind) {
if (iter_types_block[current_block_ind] == iter_type_desc) {
next_block_ind = current_block_ind - 1;
block_bind = block->iter_values[current_block_ind];
break;
}
}
Expand All @@ -1802,15 +1806,30 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,

PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
if (UsesVar(residual,
[&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); }))
[&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) {
continue;
}
// padding is allowed only when the block has trivial bindings
if (allow_padding && !is_zero(residual)) {
allow_padding = false;
}

const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();

// Check divisibility
if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
if (!int_block_extent) {
return NullOpt;
}
int64_t remainder = int_block_extent->value % int_desc_extent->value;
if (remainder != 0) {
if (allow_padding) {
// If the block loop is not divisible by the desc loop, we pad the block loop to make it
// divisible if padding is allowed.
block_index_to_padding[current_block_ind] = int_desc_extent->value - remainder;
} else {
return NullOpt;
}
}

ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
break;
Expand All @@ -1820,13 +1839,29 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
}
if (!block_index_to_padding.empty()) {
if (!allow_padding) {
return NullOpt;
}
Array<Integer> paddings;
for (int i = 0, n = block->block->iter_vars.size(); i < n; ++i) {
const IterVar& iter_var = block->block->iter_vars[i];
if (auto it = block_index_to_padding.find(i); it != block_index_to_padding.end()) {
paddings.push_back(IntImm(iter_var->var.dtype(), it->second));
} else {
paddings.push_back(IntImm(iter_var->var.dtype(), 0));
}
}
ret->block_iter_paddings = std::move(paddings);
}

return TensorizeInfo(ret);
}

TVM_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc);
TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
.set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
.set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func, bool allow_padding) {
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding);
});

/******** Auto Tensorization ********/
Expand Down
7 changes: 5 additions & 2 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,14 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
}

Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name) {
const String& intrin_name, bool allow_padding) {
Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc, allow_padding);
if (!opt_tensorize_info) return NullOpt;
const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
if (info->block_iter_paddings.defined()) {
sch->PadEinsum(block_rv, info->block_iter_paddings.value());
}
// Construct a mapping from tir loops back to LoopRVs
Map<tir::StmtSRef, LoopRV> loop2rv;
{
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
* block tiled according to the given intrin, NullOpt if a valid loop mapping is not found
*/
Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name);
const String& intrin_name, bool allow_padding = false);

/******** Block mutation ********/

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,122 @@ def test_cuda_tensor_core_matmul_relu():
check_trace(spaces, expected)


def test_cuda_tensor_core_padded_matmul_relu():
# Checks the design space generated for a float16 -> float32 matmul+relu whose
# m, n and k extents (127) are not multiples of the 16x16x16 wmma intrinsic
# named in the expected trace.
m = n = k = 127
target = Target("cuda", host="llvm")
ctx = _create_context(
create_prim_func(
te_workload.matmul_relu(
n=n,
m=m,
k=k,
in_dtype="float16",
out_dtype="float32",
)
),
target=target,
rule=[
multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
auto_inline(target),
],
)
# Exactly one design space is expected for this workload.
spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
assert len(spaces) == 1

# Expected schedule trace, one instruction per list element. The instruction
# that distinguishes it from an unpadded schedule is
# `sch.pad_einsum(block=b0, padding=[1, 1, 1])`, which is issued before the
# loops of b0 are split by factors of 16.
expected = [
"""b0 = sch.get_block(name="C", func_name="main")
b1 = sch.get_block(name="compute", func_name="main")
sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
b2 = sch.reindex(block=b0, buffer=("write", 0))
b3 = sch.reindex(block=b0, buffer=("read", 0))
b4 = sch.reindex(block=b0, buffer=("read", 1))
sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, ))
sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
sch.pad_einsum(block=b0, padding=[1, 1, 1])
l5, l6, l7 = sch.get_loops(block=b0)
l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True)
l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0)
sch.reorder(l16, l18, l13, l11, l9)
b20 = sch.blockize(loop=l13)
sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1)
l21, l22, l23 = sch.get_loops(block=b20)
v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True)
v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4)
l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True)
v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4)
l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True)
sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43)
l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
sch.bind(loop=l50, thread_axis="blockIdx.y")
l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
sch.bind(loop=l51, thread_axis="blockIdx.x")
l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
sch.bind(loop=l52, thread_axis="threadIdx.y")
b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1)
b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1)
v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
sch.reverse_compute_inline(block=b2)
l56, l57, l58, l59, l60 = sch.get_loops(block=b54)
l61, l62 = sch.split(loop=l60, factors=[None, 16], preserve_unit_iters=True)
l63, l64 = sch.split(loop=l59, factors=[None, 16], preserve_unit_iters=True)
l65, l66, l67, l68, l69, l70, l71 = sch.get_loops(block=b54)
sch.reorder(l70, l64, l62)
b72 = sch.blockize(loop=l64)
sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1)
l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1)
l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1)
l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_loops(block=b91)
sch.reorder(l110, l102, l100)
b112 = sch.blockize(loop=l102)
sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1)
l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
l125, l126, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b113)
sch.reorder(l132, l124, l122)
b134 = sch.blockize(loop=l124)
sch.annotate(block_or_loop=b134, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
sch.compute_inline(block=b3)
sch.compute_inline(block=b4)
sch.storage_align(block=b73, buffer_index=0, axis=-2, factor=32, offset=8)
sch.storage_align(block=b82, buffer_index=0, axis=-2, factor=32, offset=8)
sch.reverse_compute_inline(block=b1)""".split(
"\n"
)
]
# Compare the generated design space against the expected trace.
check_trace(spaces, expected)


def test_cuda_tensor_core_software_pipeline_matmul_relu():
m = n = k = 128
target = Target("cuda", host="llvm")
Expand Down
26 changes: 25 additions & 1 deletion tests/python/unittest/test_tir_schedule_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tvm.testing
from tvm.tir.function import TensorIntrin
from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f32_INTRIN
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_INTRIN, WMMA_SYNC_16x16x16_f16f16f32_INTRIN


from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule
Expand Down Expand Up @@ -260,6 +260,30 @@ def matmul_16x16x16xf16f16f16_desc(
assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)


def test_get_tensorize_loop_mapping_padding_matmul():
    """Check the per-iter padding reported when matmul extents do not divide the intrinsic.

    The workload uses extents 127, 256 and 65 against a 16x16x16 WMMA description,
    so a loop mapping only exists when padding is allowed. The expected paddings
    are the amounts needed to round each extent up to a multiple of 16
    (127 -> +1, 256 -> +0, 65 -> +15).
    """
    matmul = create_prim_func(
        te_workload.matmul_relu(
            n=127,
            m=256,
            k=65,
            in_dtype="float16",
            out_dtype="float16",
        )
    )
    s = Schedule(matmul)
    block = s.get_block("C")

    desc = TensorIntrin.get(WMMA_SYNC_16x16x16_f16f16f16_INTRIN).desc
    # `allow_padding` has to be passed explicitly: it is a required parameter of
    # get_tensorize_loop_mapping, and without padding the non-divisible extents
    # make the mapping fail (None) instead of reporting paddings.
    info = get_tensorize_loop_mapping(s, block, desc, allow_padding=True)
    assert info is not None
    expected_padding = [1, 0, 15]
    actual_padding = info.block_iter_paddings
    assert actual_padding is not None
    assert len(actual_padding) == len(expected_padding)
    for actual, expected in zip(actual_padding, expected_padding):
        assert actual == expected


def check_index_map(workload, block_name, intrin_name, expected_index_map):
s = Schedule(workload)
block = s.get_block(block_name)
Expand Down

0 comments on commit 1b92747

Please sign in to comment.