From 1b92747e7390af6aa5b5250c15792e369c52bda1 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Tue, 6 Sep 2022 17:18:05 -0700
Subject: [PATCH] [MetaSchedule] Support padding for irregular shapes for CUDA
 tensor core

---
 python/tvm/tir/schedule/analysis.py           |   7 ++-
 .../multi_level_tiling_tensor_core.cc         |   3 +-
 src/tir/schedule/analysis.h                   |   8 +-
 src/tir/schedule/analysis/analysis.cc         |  53 ++++++--
 src/tir/schedule/transform.cc                 |   7 +-
 src/tir/schedule/transform.h                  |   2 +-
 ...hedule_schedule_rule_multi_level_tiling.py | 116 ++++++++++++++++++
 .../unittest/test_tir_schedule_analysis.py    |  26 +++-
 8 files changed, 204 insertions(+), 18 deletions(-)

diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py
index cdb4aa9cfa20f..9f36bd4b8c273 100644
--- a/python/tvm/tir/schedule/analysis.py
+++ b/python/tvm/tir/schedule/analysis.py
@@ -68,7 +68,7 @@ class TensorizeInfo(Object):
 
 
 def get_tensorize_loop_mapping(
-    sch: Schedule, block: BlockRV, desc_func: PrimFunc
+    sch: Schedule, block: BlockRV, desc_func: PrimFunc, allow_padding: bool = False
 ) -> Optional[TensorizeInfo]:
     """Establish a mapping between loops in a target block and an intrinsic description
 
@@ -80,10 +80,11 @@ def get_tensorize_loop_mapping(
         The target block to match against
     desc_func : PrimFunc
         The prim func describing the computation to be tensorized
-
+    allow_padding : bool
+        Whether to allow padding the block iters to match the intrinsic description
     Returns
     -------
     tensorize_info : Optional[TensorizeInfo]
         TensorizeInfo structure if a valid mapping is found, None otherwise
     """
-    return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func)  # type: ignore
+    return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func, allow_padding)  # type: ignore
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 7ddda9b2635b2..1d75abfb92cf8 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -508,7 +508,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
   state->sch->TransformBlockLayout(state->tensor_core_reindex_B, index_map);
   state->sch->TransformBlockLayout(state->block_rv, index_map);
 
-  return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name);
+  return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name,
+                                   /*allow_padding=*/true);
 }
 
 inline std::vector<int> MultiLevelTilingTensorCoreNode::TransformForTensorization(
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index ca45bcac6b344..57165fd08ad44 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -731,10 +731,15 @@ class TensorizeInfoNode : public Object {
   Map<tir::StmtSRef, tir::For> loop_map;
   /*! \brief Maps loops in an intrinsic description to its index, outer to inner */
   Map<tir::For, Integer> desc_loop_indexer;
+  /*! \brief Optional padded extents of the block iters when padding is needed to match the
+   * intrinsic description
+   */
+  Optional<Array<Integer>> block_iter_paddings;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("loop_map", &loop_map);
     v->Visit("desc_loop_indexer", &desc_loop_indexer);
+    v->Visit("block_iter_paddings", &block_iter_paddings);
   }
 
   static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
@@ -751,11 +756,12 @@ class TensorizeInfo : public ObjectRef {
  * \param self The schedule state to be tensorized
  * \param block_sref The target block to match against
  * \param desc_func The prim func describing the computation to be tensorized
+ * \param allow_padding Whether to allow padding the block iters to match the intrinsic description
  * \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
  */
 Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
                                                 const tir::StmtSRef& block_sref,
-                                                const tir::PrimFunc& desc_func);
+                                                const tir::PrimFunc& desc_func, bool allow_padding);
 
 /*! \brief Necessary information used to perform transformations for tensorization */
 class AutoTensorizeMappingInfoNode : public Object {
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 4f78b0c9cd43e..b2f879c60e29b 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1699,7 +1699,8 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer,
 
 Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
                                                 const tir::StmtSRef& block_sref,
-                                                const tir::PrimFunc& desc_func) {
+                                                const tir::PrimFunc& desc_func,
+                                                bool allow_padding) {
   arith::Analyzer analyzer;
   const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
   // Step 1. Analyze desc_func, extract its block, loops and loop vars
@@ -1732,6 +1733,8 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
   const int n_desc_vars = desc_block->iter_values.size();
   const int offset = n_block_vars - n_desc_vars;
 
+  std::unordered_map<int, int> block_index_to_padding;  // padding of each block iter if necessary
+
   if (offset < 0) {
     return NullOpt;
   }
@@ -1782,10 +1785,11 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
     // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator
    // type
     PrimExpr block_bind;
-    for (int i = next_block_ind; i >= 0; --i) {
-      if (iter_types_block[i] == iter_type_desc) {
-        next_block_ind = i - 1;
-        block_bind = block->iter_values[i];
+    int current_block_ind = next_block_ind;
+    for (; current_block_ind >= 0; --current_block_ind) {
+      if (iter_types_block[current_block_ind] == iter_type_desc) {
+        next_block_ind = current_block_ind - 1;
+        block_bind = block->iter_values[current_block_ind];
         break;
       }
     }
@@ -1802,15 +1806,30 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
       PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
       if (UsesVar(residual,
-                  [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); }))
+                  [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) {
         continue;
+      }
+      // Padding is allowed only when the block has trivial bindings
+      if (allow_padding && !is_zero(residual)) {
+        allow_padding = false;
+      }
 
       const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();
 
       // Check divisibility
-      if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
+      if (!int_block_extent) {
         return NullOpt;
       }
+      int64_t remainder = int_block_extent->value % int_desc_extent->value;
+      if (remainder != 0) {
+        if (allow_padding) {
+          // The block loop extent is not divisible by the desc loop extent; pad the block
+          // loop up to the next multiple of the desc loop extent.
+          block_index_to_padding[current_block_ind] = int_desc_extent->value - remainder;
+        } else {
+          return NullOpt;
+        }
+      }
 
       ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
       break;
     }
@@ -1820,13 +1839,29 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
   for (int i = 0, n = desc_loops.size(); i < n; ++i) {
     ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
   }
+  if (!block_index_to_padding.empty()) {
+    if (!allow_padding) {
+      return NullOpt;
+    }
+    Array<Integer> paddings;
+    for (int i = 0, n = block->block->iter_vars.size(); i < n; ++i) {
+      const IterVar& iter_var = block->block->iter_vars[i];
+      if (auto it = block_index_to_padding.find(i); it != block_index_to_padding.end()) {
+        paddings.push_back(IntImm(iter_var->var.dtype(), it->second));
+      } else {
+        paddings.push_back(IntImm(iter_var->var.dtype(), 0));
+      }
+    }
+    ret->block_iter_paddings = std::move(paddings);
+  }
+
   return TensorizeInfo(ret);
 }
 
 TVM_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc);
 TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
-    .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
-      return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
+    .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func, bool allow_padding) {
+      return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding);
     });
 
 /******** Auto Tensorization ********/
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index dfbd3dbcbcc4e..b0aea49bf0a89 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -288,11 +288,14 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
 }
 
 Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
-                                      const String& intrin_name) {
+                                      const String& intrin_name, bool allow_padding) {
   Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
-      sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
+      sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc, allow_padding);
   if (!opt_tensorize_info) return NullOpt;
   const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
+  if (info->block_iter_paddings.defined()) {
+    sch->PadEinsum(block_rv, info->block_iter_paddings.value());
+  }
   // Construct a mapping from tir loops back to LoopRVs
   Map<tir::StmtSRef, LoopRV> loop2rv;
   {
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 4de3685e24825..eb90ca0139bd1 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -197,7 +197,7 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
  * block tiled according to the given intrin, NullOpt if a valid loop mapping is not found
  */
 Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
-                                      const String& intrin_name);
+                                      const String& intrin_name, bool allow_padding = false);
 
 /******** Block mutation ********/
 
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index fe1220c50925d..5d4306674b649 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -709,6 +709,122 @@ def test_cuda_tensor_core_matmul_relu():
     check_trace(spaces, expected)
 
 
+def test_cuda_tensor_core_padded_matmul_relu():
+    m = n = k = 127
+    target = Target("cuda", host="llvm")
+    ctx = _create_context(
+        create_prim_func(
+            te_workload.matmul_relu(
+                n=n,
+                m=m,
+                k=k,
+                in_dtype="float16",
+                out_dtype="float32",
+            )
+        ),
+        target=target,
+        rule=[
+            multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"),
+            auto_inline(target),
+        ],
+    )
+    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
+    assert len(spaces) == 1
+
+    expected = [
+        """b0 = sch.get_block(name="C", func_name="main")
+b1 = sch.get_block(name="compute", func_name="main")
+sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
+b2 = sch.reindex(block=b0, buffer=("write", 0))
+b3 = sch.reindex(block=b0, buffer=("read", 0))
+b4 = sch.reindex(block=b0, buffer=("read", 1))
+sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, ))
+sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, ))
+sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, ))
+sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, ))
+sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, ))
+sch.pad_einsum(block=b0, padding=[1, 1, 1])
+l5, l6, l7 = sch.get_loops(block=b0)
+l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True)
+l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True)
+l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True)
+l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0)
+sch.reorder(l16, l18, l13, l11, l9)
+b20 = sch.blockize(loop=l13)
+sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32")
+sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32")
+sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1)
+l21, l22, l23 = sch.get_loops(block=b20)
+v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4)
+l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True)
+v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4)
+l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True)
+v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4)
+l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True)
+sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43)
+l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
+sch.bind(loop=l50, thread_axis="blockIdx.y")
+l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
+sch.bind(loop=l51, thread_axis="blockIdx.x")
+l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
+sch.bind(loop=l52, thread_axis="threadIdx.y")
+b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared")
+sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1)
+b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator")
+sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1)
+v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)
+sch.reverse_compute_inline(block=b2)
+l56, l57, l58, l59, l60 = sch.get_loops(block=b54)
+l61, l62 = sch.split(loop=l60, factors=[None, 16], preserve_unit_iters=True)
+l63, l64 = sch.split(loop=l59, factors=[None, 16], preserve_unit_iters=True)
+l65, l66, l67, l68, l69, l70, l71 = sch.get_loops(block=b54)
+sch.reorder(l70, l64, l62)
+b72 = sch.blockize(loop=l64)
+sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared")
+b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared")
+sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1)
+l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73)
+l80 = sch.fuse(l78, l79, preserve_unit_iters=True)
+v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
+b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared")
+sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1)
+l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82)
+l89 = sch.fuse(l87, l88, preserve_unit_iters=True)
+v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25])
+sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90)
+b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a")
+sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1)
+l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91)
+l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True)
+l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True)
+l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_loops(block=b91)
+sch.reorder(l110, l102, l100)
+b112 = sch.blockize(loop=l102)
+sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
+b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b")
+sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1)
+l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113)
+l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True)
+l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True)
+l125, l126, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b113)
+sch.reorder(l132, l124, l122)
+b134 = sch.blockize(loop=l124)
+sch.annotate(block_or_loop=b134, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
+sch.compute_inline(block=b3)
+sch.compute_inline(block=b4)
+sch.storage_align(block=b73, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.storage_align(block=b82, buffer_index=0, axis=-2, factor=32, offset=8)
+sch.reverse_compute_inline(block=b1)""".split(
+            "\n"
+        )
+    ]
+    check_trace(spaces, expected)
+
+
 def test_cuda_tensor_core_software_pipeline_matmul_relu():
     m = n = k = 128
     target = Target("cuda", host="llvm")
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 5524abbaf094d..0e6962cb8448a 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -21,7 +21,7 @@
 import tvm.testing
 from tvm.tir.function import TensorIntrin
 from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc
-from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f32_INTRIN
+from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_INTRIN, WMMA_SYNC_16x16x16_f16f16f32_INTRIN
 
 
 from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule
@@ -260,6 +260,30 @@ def matmul_16x16x16xf16f16f16_desc(
     assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
 
 
+def test_get_tensorize_loop_mapping_padding_matmul():
+    matmul = create_prim_func(
+        te_workload.matmul_relu(
+            n=127,
+            m=256,
+            k=65,
+            in_dtype="float16",
+            out_dtype="float16",
+        )
+    )
+    s = Schedule(matmul)
+    block = s.get_block("C")
+
+    desc = TensorIntrin.get(WMMA_SYNC_16x16x16_f16f16f16_INTRIN).desc
+    info = get_tensorize_loop_mapping(s, block, desc, allow_padding=True)
+    assert info is not None
+    expected_padding = [1, 0, 15]
+    actual_padding = info.block_iter_paddings
+    assert actual_padding is not None
+    assert len(actual_padding) == len(expected_padding)
+    for actual, expected in zip(actual_padding, expected_padding):
+        assert actual == expected
+
+
 def check_index_map(workload, block_name, intrin_name, expected_index_map):
     s = Schedule(workload)
     block = s.get_block(block_name)