From f18009fa203917762443a0a5a8a135f6e4955a73 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 3 Feb 2023 16:37:52 -0800 Subject: [PATCH 01/15] [MetaSchedule] Tile and pack intermediate output for CUDA TensorCore --- src/meta_schedule/postproc/verify_gpu_code.cc | 2 + .../schedule_rule/multi_level_tiling.cc | 14 +- .../schedule_rule/multi_level_tiling.h | 3 +- .../multi_level_tiling_tensor_core.cc | 138 +++++++++++++++++- .../multi_level_tiling_wide_vector.cc | 15 +- .../manifest_shared_memory_local_stage.cc | 3 +- 6 files changed, 155 insertions(+), 20 deletions(-) diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 99ffc1bfcdf7..a19d2d0eeb14 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -162,6 +162,7 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); pass_list.push_back(tir::transform::UnifyThreadBinding()); + pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); @@ -189,6 +190,7 @@ class VerifyGPUCodeNode : public PostprocNode { IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); } catch (const dmlc::Error& e) { + LOG(INFO) << e.what(); return false; } if (!Verify(lowered)) { diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 779114e9cfea..8b7d613563f1 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -186,15 +186,15 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { return results; } -Array MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, - int n_tiles) const { +std::pair, Array> MultiLevelTilingNode::SplitLoop( + const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const { Array factors = sch->SamplePerfectTile( /*loop=*/loop, /*n=*/n_tiles, /*max_innermost_factor=*/max_innermost_factor); Array splits = sch->Split(/*loop=*/loop, /*factors=*/{factors.begin(), factors.end()}); - return splits; + return {factors, splits}; } std::vector MultiLevelTilingNode::TileLoopNest(State state) const { @@ -207,6 +207,9 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { // Step 2. 
For each loop axis, tile it int64_t spatial_loop_product = 1; std::vector> tiles(s_indices_.size() + r_indices_.size()); + state->tile_factors.resize(tiles.size()); + std::vector> tile_factors; + tile_factors.resize(tiles.size()); for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; const std::vector* idx = nullptr; @@ -231,14 +234,17 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { if (n_tiles == 1) { tiles[idx->at(0)].push_back(loop); } else { - auto splits = SplitLoop(sch, block_rv, loop, n_tiles); + auto [factors, splits] = SplitLoop(sch, block_rv, loop, n_tiles); // Put every tile to its slot for (int j = 0; j < n_tiles; ++j) { tiles[idx->at(j)].push_back(splits[j]); + tile_factors[idx->at(j)].push_back(factors[j]); + // Array& a=state->tile_size[idx->at(j)];//.push_back(factors[j]); } } } + state->tile_factors = std::move(tile_factors); // Step 3. Reorder to organize the tiles sch->Reorder(support::ConcatArrayList(tiles.begin(), tiles.end())); // Step 4. Bind the tiles to threads diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index ff38756ff06b..1290491a1f75 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -98,6 +98,7 @@ class StateNode : public Object { std::unordered_map read_reuse; /*! \brief The mapping from buffer index to write cache block. */ std::unordered_map write_reuse; + Array> tile_factors; /*! * \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that @@ -163,7 +164,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { protected: virtual std::vector ApplySubRules(std::vector states); - virtual Array SplitLoop(const tir::Schedule& sch, tir::BlockRV block, + virtual std::pair, Array> SplitLoop(const tir::Schedule& sch, tir::BlockRV block, tir::LoopRV loop, int n_tiles) const; // Annotate a block to use cooperative fetching diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index d5cca52d41f9..c1e7cadbbbfc 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -124,6 +125,9 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode { private: // SubRule: Add tensorization-related transformations inline std::vector TransformForTensorization(TensorCoreState state) const; + // Subrule: Transform the layout of the output. This is necessary for efficient cache write the + // output in the shared memory. 
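  // A rough sketch of the resulting shape, assuming 16x16 wmma fragments (the names below are
  // purely illustrative): a 2-D output [M, N] written by the main block is repacked into
  //   [num_warps_m, num_warps_n, frags_per_warp_m, frags_per_warp_n, 16, 16]
  // so that each warp owns a contiguous tile of whole accumulator fragments.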
+ std::vector TransformIntermediateOutputLayout(TensorCoreState state); // Subrule: Add tensorized load inline std::vector AddReadReuseTensorCore(TensorCoreState state) const; // Subrule: Add tensorized store @@ -225,6 +229,9 @@ std::vector MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector(state)); }); states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { + return TransformIntermediateOutputLayout(Downcast(state)); + }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); states = SubRule(std::move(states), [&](State state) { return AddWriteReuseTensorCore(Downcast(state)); @@ -248,6 +255,88 @@ void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch, (*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name); } +std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLayout( + TensorCoreState state) { + // Get the shape of the wmma accumulator + tir::Block intrin_block = + Downcast( + tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body) + ->block; + tir::For loop_m = Downcast(intrin_block->body); + tir::For loop_n = Downcast(loop_m->body); + PrimExpr accumulator_m = loop_m->extent; + PrimExpr accumulator_n = loop_n->extent; + Schedule& sch = state->sch; + + auto buffer_ndim = sch->Get(state->block_rv)->writes[0]->buffer->shape.size(); + + // The dimension of the buffer should be larger or same as that of the tensor intrin. + ICHECK_GE(buffer_ndim, 2); + + auto index_map_pack_accumulator_tile = + tir::IndexMap::FromFunc(buffer_ndim, [&](const Array& indices) -> Array { + const auto& i = indices[buffer_ndim - 2]; + const auto& j = indices[buffer_ndim - 1]; + Array result; + for (int i = 0; i < buffer_ndim - 2; ++i) { + result.push_back(indices[i]); + } + result.push_back(floordiv(i, accumulator_m)); + result.push_back(floordiv(j, accumulator_n)); + result.push_back(floormod(i, accumulator_m)); + result.push_back(floormod(j, accumulator_n)); + return result; + }); + sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, + index_map_pack_accumulator_tile); + auto it = std::find(tile_binds.begin(), tile_binds.end(), "threadIdx.y"); + ICHECK(it != tile_binds.end()); + auto idx = std::distance(tile_binds.begin(), it); + auto f_get_tile_product = [&](int loop_idx) { + Array factors; + for (int i = idx + 1; i < s_indices_.size(); ++i) { + auto s_factors = state->tile_factors[s_indices_[i]]; + if (loop_idx < 0) { + loop_idx += s_factors.size(); + } + factors.push_back(s_factors[loop_idx]); + } + ICHECK(!factors.empty()); + if (factors.size() == 1) { + return factors[0]; + } + auto result = factors[0]; + for (int i = 1; i < factors.size(); ++i) { + result = result * factors[i]; + } + return result; + }; + auto warp_factor_m = f_get_tile_product(-2); + auto warp_factor_n = f_get_tile_product(-1); + auto index_map_pack_warp_tile = + tir::IndexMap::FromFunc(buffer_ndim + 2, [&](const Array& indices) -> Array { + const auto& i = indices[indices.size() - 4]; + const auto& j = indices[indices.size() - 3]; + const auto& m = indices[indices.size() - 2]; + const auto& n = indices[indices.size() - 1]; + Array result; + for (int i = 0; i < indices.size() - 4; ++i) { + result.push_back(indices[i]); + } + result.push_back(floordiv(i, warp_factor_m)); + result.push_back(floordiv(j, warp_factor_n)); + result.push_back(floormod(i, warp_factor_m)); + result.push_back(floormod(j, 
warp_factor_n)); + result.push_back(m); + result.push_back(n); + return result; + }); + sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map_pack_warp_tile); + sch->GetLoops(state->block_rv); + + return {state}; +} + std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( TensorCoreState state) const { // Add the cache write stage for Tensor Core @@ -255,18 +344,51 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( const LoopRV& loop = state->tiles[level].back(); Schedule& sch = state->sch; auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator"); - sch->ReverseComputeAt(cache_write, loop, true); if (state->write_reuse.count(0)) { - // Fuse the iterators of the cache_write - Array buffer_loops = sch->GetLoops(state->write_reuse[0]); - ICHECK_GT(buffer_loops.size(), 2); - sch->Fuse(Array{buffer_loops.end() - 2, // The src shmem is always 2D - buffer_loops.end()}); - AnnotateCooperativeFetching(&sch, state->write_reuse[0]); + auto f_get_loops = [&](const BlockRV& block_rv) -> std::tuple { + Array buffer_loops = sch->GetLoops(block_rv); + ICHECK_GT(buffer_loops.size(), 6); + return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5], + buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]}; + }; + + { + const auto& [i0, j0, i1, j1] = f_get_loops(state->write_reuse[0]); + sch->Reorder({i1, i0, j0, j1}); + sch->ComputeAt(cache_write, i1, true); + } + { + const auto& [i0, j0, i1, j1] = f_get_loops(cache_write); + auto fused = sch->Fuse({i0, j0}); + sch->Bind(fused, "threadIdx.y"); + } } + // sch->ReverseComputeAt(cache_write, loop, true); + + // if (state->write_reuse.count(0)) { + // // Fuse the iterators of the cache_write + // Array buffer_loops = sch->GetLoops(state->write_reuse[0]); + // ICHECK_GT(buffer_loops.size(), 2); + // sch->Fuse(Array{buffer_loops.end() - 2, // The src shmem is always 2D + // buffer_loops.end()}); + // AnnotateCooperativeFetching(&sch, state->write_reuse[0]); + // } sch->ReverseComputeInline(state->tensor_core_reindex_store); - TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin); + auto loops = sch->GetLoops(cache_write); + auto blockized_store = sch->Blockize(loops[loops.size() - 2]); + sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize, + state->intrin_group.store_intrin); + + Array buffer_loops = sch->GetLoops(state->write_reuse[0]); + ICHECK_GT(buffer_loops.size(), 5); + sch->Fuse(Array{buffer_loops.end() - 5, // The src shmem is always 2D + buffer_loops.end()}); + AnnotateCooperativeFetching(&sch, state->write_reuse[0]); + + // + // + // TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin); return {state}; } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index d4c4a10fdd72..e68b64ea2d3a 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -48,11 +48,12 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { return ScheduleRule(n); } - Array SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const; + std::pair, Array> SplitLoop(const Schedule& sch, BlockRV block, + LoopRV loop, int n_tiles) const; }; -Array MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv, - LoopRV loop_rv, int n_tiles) const { 
+std::pair, Array> MultiLevelTilingWideVectorNode::SplitLoop( + const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, int n_tiles) const { const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); const tir::StmtSRef block_sref = sch->GetSRef(block_rv); const tir::BlockNode* block_node = block_sref->StmtAs(); @@ -99,12 +100,14 @@ Array MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch Array outer_splits = sch->Split( /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()}); outer_splits.push_back(inner_splits[1]); - return outer_splits; + outer_factors.push_back(PrimExpr(vec_len)); + return {outer_factors, outer_splits}; } else { Array factors(n_tiles - 1, PrimExpr(1)); factors.push_back(loop->extent); - return sch->Split(/*loop=*/loop_rv, - /*factors=*/{factors.begin(), factors.end()}); + Array splits = sch->Split(/*loop=*/loop_rv, + /*factors=*/{factors.begin(), factors.end()}); + return {factors, splits}; } } } diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 0f56c8b8b7c9..5ddc4526e854 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -53,7 +53,8 @@ class IntermediateStageRewriter { const BufferStoreNode* store = block->body.as(); CHECK(store != nullptr && runtime::StorageScope::Create(store->buffer.scope()).rank == runtime::StorageRank::kShared) - << "ValueError: Expect the body of the block to be BufferStore to shared memory."; + << "ValueError: Expect the body of the block to be BufferStore to shared memory." + << "But get " << block->body; const Buffer& target_buffer = store->buffer; From a6618351c6996a250947af08fdf306603a2f600e Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 22 Feb 2023 14:28:30 -0800 Subject: [PATCH 02/15] clean up schedule rule mltc --- .../multi_level_tiling_tensor_core.cc | 239 +++++++++++------- 1 file changed, 147 insertions(+), 92 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index c1e7cadbbbfc..e1d6d0651d9b 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -25,7 +25,6 @@ #include "../utils.h" #include "./multi_level_tiling.h" - namespace tvm { namespace meta_schedule { @@ -257,44 +256,36 @@ void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch, std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLayout( TensorCoreState state) { - // Get the shape of the wmma accumulator - tir::Block intrin_block = - Downcast( - tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body) - ->block; - tir::For loop_m = Downcast(intrin_block->body); - tir::For loop_n = Downcast(loop_m->body); - PrimExpr accumulator_m = loop_m->extent; - PrimExpr accumulator_n = loop_n->extent; - Schedule& sch = state->sch; - - auto buffer_ndim = sch->Get(state->block_rv)->writes[0]->buffer->shape.size(); + // Transform the intermediate output to packed layout + // [..., warp_m, warp_n, accum_frag_m, accum_frag_n, accum_elem_m, accum_elem_n] + // where warp_m, warp_n are thread indices bound to the warp id, accum_frag_m, accum_frag_n are + // the index of the fragments in each warp, accum_elem_m, accum_elem_n are the index of the + // elements in each accumulator fragment. 
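  // A worked example, assuming 16x16 fragments and warp_num_frag_m = warp_num_frag_n = 2
  // (both values are derived per workload rather than fixed here): the flat output element
  // (m, n) = (37, 50) maps to
  //   accum_elem_m = 37 % 16 = 5        accum_elem_n = 50 % 16 = 2
  //   fragment     = (37 / 16, 50 / 16) = (2, 3)
  //   warp_m       = 2 / 2 = 1          warp_n       = 3 / 2 = 1
  //   accum_frag_m = 2 % 2 = 0          accum_frag_n = 3 % 2 = 1
  // i.e. the packed index is [..., 1, 1, 0, 1, 5, 2].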
- // The dimension of the buffer should be larger or same as that of the tensor intrin. - ICHECK_GE(buffer_ndim, 2); - - auto index_map_pack_accumulator_tile = - tir::IndexMap::FromFunc(buffer_ndim, [&](const Array& indices) -> Array { - const auto& i = indices[buffer_ndim - 2]; - const auto& j = indices[buffer_ndim - 1]; - Array result; - for (int i = 0; i < buffer_ndim - 2; ++i) { - result.push_back(indices[i]); - } - result.push_back(floordiv(i, accumulator_m)); - result.push_back(floordiv(j, accumulator_n)); - result.push_back(floormod(i, accumulator_m)); - result.push_back(floormod(j, accumulator_n)); - return result; - }); - sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, - index_map_pack_accumulator_tile); + // Get the shape of the wmma accumulator + auto [frag_shape_m, frag_shape_n] = [&]() { + tir::Block intrin_block = + Downcast( + tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body) + ->block; + tir::For loop_m = Downcast(intrin_block->body); + tir::For loop_n = Downcast(loop_m->body); + return std::make_tuple(loop_m->extent, loop_n->extent); + }(); + + // Get the tile index of the warp id (i.e. threadIdx.y) auto it = std::find(tile_binds.begin(), tile_binds.end(), "threadIdx.y"); ICHECK(it != tile_binds.end()); - auto idx = std::distance(tile_binds.begin(), it); - auto f_get_tile_product = [&](int loop_idx) { + auto tile_index_warp_id = std::distance(tile_binds.begin(), it); + + // Get the extent of loop indicated by `loop_idx` inside the warp scope. + // For example, after spatial loops i, j are tiled, we will have + // tile_factors = ((i0, j0), (i1, j1), ..., (in, jn)) + // This function computes the product of tile_factors[i][loop_idx] for i > tile_index_warp_id. + // `loop_idx` can be negative, in which case it is counted from the end. 
+ auto f_get_inner_tile_product = [&](int loop_idx) { Array factors; - for (int i = idx + 1; i < s_indices_.size(); ++i) { + for (int i = tile_index_warp_id + 1; i < static_cast(s_indices_.size()); ++i) { auto s_factors = state->tile_factors[s_indices_[i]]; if (loop_idx < 0) { loop_idx += s_factors.size(); @@ -306,33 +297,82 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa return factors[0]; } auto result = factors[0]; - for (int i = 1; i < factors.size(); ++i) { + for (int i = 1; i < static_cast(factors.size()); ++i) { result = result * factors[i]; } return result; }; - auto warp_factor_m = f_get_tile_product(-2); - auto warp_factor_n = f_get_tile_product(-1); - auto index_map_pack_warp_tile = - tir::IndexMap::FromFunc(buffer_ndim + 2, [&](const Array& indices) -> Array { - const auto& i = indices[indices.size() - 4]; - const auto& j = indices[indices.size() - 3]; - const auto& m = indices[indices.size() - 2]; - const auto& n = indices[indices.size() - 1]; - Array result; - for (int i = 0; i < indices.size() - 4; ++i) { - result.push_back(indices[i]); - } - result.push_back(floordiv(i, warp_factor_m)); - result.push_back(floordiv(j, warp_factor_n)); - result.push_back(floormod(i, warp_factor_m)); - result.push_back(floormod(j, warp_factor_n)); - result.push_back(m); - result.push_back(n); - return result; - }); - sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map_pack_warp_tile); - sch->GetLoops(state->block_rv); + + // Compute the number of output fragment of each warp + auto warp_num_frag_m = f_get_inner_tile_product(-2); + auto warp_num_frag_n = f_get_inner_tile_product(-1); + + Schedule& sch = state->sch; + int buffer_ndim = static_cast(sch->Get(state->block_rv)->writes[0]->buffer->shape.size()); + // The dimension of the buffer should be larger or same as that of the tensor intrin. 
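  // (Any leading dimensions beyond the last two are treated as batch dimensions and carried
  // through the packed layout unchanged, e.g. a [B, M, N] output keeps B in front of the six
  // packed axes; `num_higher_dims` below accounts for this.)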
+ ICHECK_GE(buffer_ndim, 2); + int num_higher_dims = buffer_ndim - 2; + + sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, + tir::IndexMap::FromFunc(buffer_ndim, [&](const Array& indices) { + Array result; + result.reserve(indices.size() + 4); + for (int i = 0; i < num_higher_dims; ++i) { + result.push_back(indices[i]); + } + const auto& m = indices[num_higher_dims]; + const auto& n = indices[num_higher_dims + 1]; + auto accum_m = floormod(m, frag_shape_m); + auto accum_n = floormod(n, frag_shape_n); + auto outer_m = floordiv(m, frag_shape_m); + auto outer_n = floordiv(n, frag_shape_n); + + result.push_back(floordiv(outer_m, warp_num_frag_m)); + result.push_back(floordiv(outer_n, warp_num_frag_n)); + result.push_back(floormod(outer_m, warp_num_frag_m)); + result.push_back(floormod(outer_n, warp_num_frag_n)); + result.push_back(accum_m); + result.push_back(accum_n); + return result; + })); + + // // Tile by the fragment shape + // sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, + // tir::IndexMap::FromFunc(buffer_ndim, [&](const Array& indices) { + // Array result; + // result.reserve(indices.size() + 2); + // for (int i = 0; i < num_higher_dims; ++i) { + // result.push_back(indices[i]); + // } + // const auto& m = indices[num_higher_dims]; + // const auto& n = indices[num_higher_dims + 1]; + // result.push_back(floordiv(m, frag_shape_m)); + // result.push_back(floordiv(n, frag_shape_n)); + // result.push_back(floormod(m, frag_shape_m)); + // result.push_back(floormod(n, frag_shape_n)); + // return result; + // })); + + // // Tile by the number of fragments + // sch->TransformLayout( + // state->block_rv, 0, tir::BufferIndexType::kWrite, + // tir::IndexMap::FromFunc(buffer_ndim + 2, [&](const Array& indices) { + // Array result; + // result.reserve(indices.size() + 2); + // for (int i = 0; i < num_higher_dims; ++i) { + // result.push_back(indices[i]); + // } + // const auto& m = indices[num_higher_dims]; + // const auto& n = indices[num_higher_dims + 1]; + // result.push_back(floordiv(m, warp_num_frag_m)); + // result.push_back(floordiv(n, warp_num_frag_n)); + // result.push_back(floormod(m, warp_num_frag_m)); + // result.push_back(floormod(n, warp_num_frag_n)); + // // The last two indices are the fragment element indices + // result.push_back(indices[num_higher_dims + 2]); + // result.push_back(indices[num_higher_dims + 3]); + // return result; + // })); return {state}; } @@ -340,40 +380,55 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( TensorCoreState state) const { // Add the cache write stage for Tensor Core - int level = r_indices_.front() - 1; - const LoopRV& loop = state->tiles[level].back(); Schedule& sch = state->sch; auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator"); - if (state->write_reuse.count(0)) { - auto f_get_loops = [&](const BlockRV& block_rv) -> std::tuple { - Array buffer_loops = sch->GetLoops(block_rv); - ICHECK_GT(buffer_loops.size(), 6); - return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5], - buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]}; - }; - - { - const auto& [i0, j0, i1, j1] = f_get_loops(state->write_reuse[0]); - sch->Reorder({i1, i0, j0, j1}); - sch->ComputeAt(cache_write, i1, true); - } - { - const auto& [i0, j0, i1, j1] = f_get_loops(cache_write); - auto fused = sch->Fuse({i0, j0}); - sch->Bind(fused, "threadIdx.y"); - } + // 
The compute block has been tiled by the warp shape and the fragment shape. + // We need to bind the cache write block (from the accumulator to the shared memory) to the warp + // id. The schedule is as follows: + // + // After adding cache write for wmma.accumulator, we will have + // for i0, j0, i1, j1, accum_m, accum_n: + // shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n] + // for i0', j0', i1', j1', accum_m', accum_n': + // global_mem[i0', j0', i1', j1', accum_m', accum_n'] = + // shared_mem[i0', j0', i1', j1', accum_m', accum_n'] + // where i0' and j0' are already bound to the block id and warp id. + // + // To reduce the shared memory usage and allow efficient data movement, we will apply + // transformations to generate the following schedule: + // + // for i1': + // for i0_j0 (fused and bound to threadIdx.y): + // for j1, accum_m, accum_n: + // shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n] + // for i0', j0', j1', accum_m', accum_n': + // global_mem[i0', j0', i1', j1', accum_m', accum_n'] = + // shared_mem[i0', j0', i1', j1', accum_m', accum_n'] + // + // i1' is reordered to the outermost. This effectively allows only a row (i.e. loop i1') of the + // fragments are moved to the shared memory and then to the global memory each time. + // As a result, shared memory for the output will only have shape of [j1, accum_m, accum_n] + // instead of [i0 * i1 * accum_m, j0 * j1 * accum_n]. + + // Get the loops other than the innermost two loops (accum_m and accum_n). + auto f_get_loops = [&](const BlockRV& block_rv) -> std::array { + Array buffer_loops = sch->GetLoops(block_rv); + ICHECK_GT(buffer_loops.size(), 6); + return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5], + buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]}; + }; + { + const auto& [i0, j0, i1, j1] = f_get_loops(state->write_reuse[0]); + sch->Reorder({i1, i0, j0, j1}); + sch->ComputeAt(cache_write, i1, true); + } + { + const auto& [i0, j0, i1, j1] = f_get_loops(cache_write); + auto fused = sch->Fuse({i0, j0}); + sch->Bind(fused, "threadIdx.y"); } - // sch->ReverseComputeAt(cache_write, loop, true); - - // if (state->write_reuse.count(0)) { - // // Fuse the iterators of the cache_write - // Array buffer_loops = sch->GetLoops(state->write_reuse[0]); - // ICHECK_GT(buffer_loops.size(), 2); - // sch->Fuse(Array{buffer_loops.end() - 2, // The src shmem is always 2D - // buffer_loops.end()}); - // AnnotateCooperativeFetching(&sch, state->write_reuse[0]); - // } + sch->ReverseComputeInline(state->tensor_core_reindex_store); auto loops = sch->GetLoops(cache_write); auto blockized_store = sch->Blockize(loops[loops.size() - 2]); @@ -385,10 +440,6 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( sch->Fuse(Array{buffer_loops.end() - 5, // The src shmem is always 2D buffer_loops.end()}); AnnotateCooperativeFetching(&sch, state->write_reuse[0]); - - // - // - // TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin); return {state}; } @@ -691,6 +742,10 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( auto node = MultiLevelTilingInitCommon( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); + CHECK(node->reuse_write_.req == ReuseType::kMustReuse && + runtime::StorageScope::Create(node->reuse_write_.scope).rank == runtime::StorageRank::kShared) + << "ValueError: Shared memory write reuse must be enabled for 
MultiLevelTilingTensorCore."; + node->intrin_groups.reserve(intrin_groups.size()); for (const auto& intrin_group_config : intrin_groups) { node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config)); From 596b83ffddb77b251e389e04901325c6104f693c Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 22 Feb 2023 14:28:58 -0800 Subject: [PATCH 03/15] add lhs analyzer --- src/tir/schedule/ir_comparator.cc | 16 +++++++++++----- src/tir/schedule/ir_comparator.h | 6 +++++- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 9d89c641630b..d93087fc92fc 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -43,7 +43,7 @@ class TensorIntrinMismatchError : public ScheduleError { std::ostringstream os; os << "The stmt {0} doesn't match the tensor intrin\nThe pattern attempting to be matched:\n" << lhs_stmt_ << "\nDoes not match the tensorize description:\n" - << rhs_stmt_; + << rhs_stmt_ << '\n'; for (const auto& msg : error_messages_) { os << msg << std::endl; } @@ -173,6 +173,9 @@ bool TensorizeComparator::VisitStmt_(const BlockRealizeNode* op, const Stmt& oth bool TensorizeComparator::VisitStmt_(const BlockNode* op, const Stmt& other) { const auto* rhs = other.as(); + for (const IterVar& iter : op->iter_vars) { + lhs_analyzer_.Bind(iter->var, iter->dom); + } // Check block equality. // All iter vars and buffer regions including the order should match. // When checking iter vars, DefEqual is used to remap variables. @@ -313,7 +316,7 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { equal_map_[lhs] = rhs; // Cast if necessary. This allows the workload and the tensor intrin to have different dtypes in // the indices. 
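  // (For example, a workload whose loop vars are int64 can still match an intrin description
  // written with int32 vars: binding lhs to cast(lhs.dtype(), rhs) lets the analyzer prove the
  // corresponding index expressions equal despite the dtype difference.)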
- analyzer_.Bind(lhs, cast(lhs.dtype(), rhs)); + analyzer_.Bind(lhs, cast(lhs.dtype(), rhs), /*allow_override=*/true); return true; } @@ -433,7 +436,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf } return false; } - indices_base.emplace_back(lhs->region[i]->min); + indices_base.emplace_back(lhs_analyzer_.Simplify(lhs->region[i]->min)); } for (size_t i = 0; i < rhs->region.size(); i++) { // save base index @@ -465,12 +468,15 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf } return false; } - if (!analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) { + if (!lhs_analyzer_.CanProveEqual(indices_base[i], lhs->region[i]->min)) { if (assert_mode_) { std::ostringstream os; os << "Buffer base index consistency check failed due to unequal index base: " "indices_base[i]=" << indices_base[i] << " vs lhs->region[i]->min=" << lhs->region[i]->min; + os << "\ni=" << i << ", offset=" << offset << ", lhs->region.qsize()=" + << lhs->region.size() << ", rhs->region.size()=" << rhs->region.size() << lhs->region << rhs->region; + os << "\nTrying simplify: " << analyzer_.Simplify(lhs->region[i]->min); EmitError(os.str()); } return false; @@ -487,7 +493,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf } return false; } - PrimExpr normalized_lhs_min = (lhs->region[i + offset]->min - indices_base[i + offset]); + PrimExpr normalized_lhs_min = lhs_analyzer_.Simplify((lhs->region[i + offset]->min - indices_base[i + offset])); if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) { if (assert_mode_) { std::ostringstream os; diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index 394d82867393..b37fb6654670 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -102,8 +102,12 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { bool assert_mode_; /*! \brief Whether it is visiting the scope block (the outermost block). */ bool is_scope_block = true; - /*! \brief The arithmetic analyzer. */ + /*! \brief The arithmetic analyzer for comparing LHS and RHS */ arith::Analyzer analyzer_; + /*! \brief The arithmetic analyzer for simplifying expressions on LHS. + * This analyzer only contains the domains of the iterators on LHS. + */ + arith::Analyzer lhs_analyzer_; /*! \brief Additional error messages. Only used when assert_mode is true. */ std::vector error_messages_; // variable remap if any From d85b771ec6c0bdd5f5370a8d5f146c535e47acc0 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 22 Feb 2023 14:43:31 -0800 Subject: [PATCH 04/15] prevent simplifying single point --- src/tir/analysis/block_access_region_detector.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index e9bff1b6fdee..ab328efaa6d1 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -76,8 +76,6 @@ class BlockReadWriteDetector : public StmtExprVisitor { Map buffer_var_map_; /*! \brief The target buffer var mapping to its matching */ std::unordered_map match_buffers_; - /*! \brief The analyzer for simplifying*/ - arith::Analyzer analyzer_; /*! 
* \brief Update read/write buffers and regions with provided buffer and region @@ -330,7 +328,12 @@ Array BlockReadWriteDetector::CollectRegions( ICHECK_EQ(buffers[i]->shape.size(), regions[i].size()); for (size_t j = 0; j < regions[i].size(); j++) { const tvm::arith::IntSet& range = regions[i][j]; - region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); + if (range.IsSinglePoint()) { + PrimExpr min = range.min(); + region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1))); + } else { + region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); + } } res.push_back(BufferRegion(buffers[i], region)); } From 270dd5e815bf1ca3fe7859a36d2302f00674f6dd Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 23 Feb 2023 14:51:03 -0800 Subject: [PATCH 05/15] clean up --- src/meta_schedule/postproc/verify_gpu_code.cc | 1 - .../multi_level_tiling_tensor_core.cc | 46 ++++++++++--------- src/tir/schedule/ir_comparator.cc | 8 ++-- .../manifest_shared_memory_local_stage.cc | 3 +- 4 files changed, 28 insertions(+), 30 deletions(-) diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index a19d2d0eeb14..6f9b46a0f734 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -190,7 +190,6 @@ class VerifyGPUCodeNode : public PostprocNode { IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); } catch (const dmlc::Error& e) { - LOG(INFO) << e.what(); return false; } if (!Verify(lowered)) { diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index e1d6d0651d9b..27554f5985d8 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -25,6 +25,7 @@ #include "../utils.h" #include "./multi_level_tiling.h" + namespace tvm { namespace meta_schedule { @@ -314,27 +315,27 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa int num_higher_dims = buffer_ndim - 2; sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, - tir::IndexMap::FromFunc(buffer_ndim, [&](const Array& indices) { - Array result; - result.reserve(indices.size() + 4); - for (int i = 0; i < num_higher_dims; ++i) { - result.push_back(indices[i]); - } - const auto& m = indices[num_higher_dims]; - const auto& n = indices[num_higher_dims + 1]; - auto accum_m = floormod(m, frag_shape_m); - auto accum_n = floormod(n, frag_shape_n); - auto outer_m = floordiv(m, frag_shape_m); - auto outer_n = floordiv(n, frag_shape_n); - - result.push_back(floordiv(outer_m, warp_num_frag_m)); - result.push_back(floordiv(outer_n, warp_num_frag_n)); - result.push_back(floormod(outer_m, warp_num_frag_m)); - result.push_back(floormod(outer_n, warp_num_frag_n)); - result.push_back(accum_m); - result.push_back(accum_n); - return result; - })); + tir::IndexMap::FromFunc(buffer_ndim, [&](const Array& indices) { + Array result; + result.reserve(indices.size() + 4); + for (int i = 0; i < num_higher_dims; ++i) { + result.push_back(indices[i]); + } + const auto& m = indices[num_higher_dims]; + const auto& n = indices[num_higher_dims + 1]; + auto accum_m = floormod(m, frag_shape_m); + auto accum_n = floormod(n, frag_shape_n); + auto outer_m = floordiv(m, frag_shape_m); + auto outer_n = floordiv(n, frag_shape_n); + + 
result.push_back(floordiv(outer_m, warp_num_frag_m)); + result.push_back(floordiv(outer_n, warp_num_frag_n)); + result.push_back(floormod(outer_m, warp_num_frag_m)); + result.push_back(floormod(outer_n, warp_num_frag_n)); + result.push_back(accum_m); + result.push_back(accum_n); + return result; + })); // // Tile by the fragment shape // sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, @@ -743,7 +744,8 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); CHECK(node->reuse_write_.req == ReuseType::kMustReuse && - runtime::StorageScope::Create(node->reuse_write_.scope).rank == runtime::StorageRank::kShared) + runtime::StorageScope::Create(node->reuse_write_.scope).rank == + runtime::StorageRank::kShared) << "ValueError: Shared memory write reuse must be enabled for MultiLevelTilingTensorCore."; node->intrin_groups.reserve(intrin_groups.size()); diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index d93087fc92fc..045f45aa6634 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -316,7 +316,7 @@ bool TensorizeComparator::DefEqual(const Var& lhs, const Var& rhs) { equal_map_[lhs] = rhs; // Cast if necessary. This allows the workload and the tensor intrin to have different dtypes in // the indices. - analyzer_.Bind(lhs, cast(lhs.dtype(), rhs), /*allow_override=*/true); + analyzer_.Bind(lhs, cast(lhs.dtype(), rhs)); return true; } @@ -474,9 +474,6 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf os << "Buffer base index consistency check failed due to unequal index base: " "indices_base[i]=" << indices_base[i] << " vs lhs->region[i]->min=" << lhs->region[i]->min; - os << "\ni=" << i << ", offset=" << offset << ", lhs->region.qsize()=" - << lhs->region.size() << ", rhs->region.size()=" << rhs->region.size() << lhs->region << rhs->region; - os << "\nTrying simplify: " << analyzer_.Simplify(lhs->region[i]->min); EmitError(os.str()); } return false; @@ -493,7 +490,8 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf } return false; } - PrimExpr normalized_lhs_min = lhs_analyzer_.Simplify((lhs->region[i + offset]->min - indices_base[i + offset])); + PrimExpr normalized_lhs_min = + lhs_analyzer_.Simplify((lhs->region[i + offset]->min - indices_base[i + offset])); if (!analyzer_.CanProveEqual(normalized_lhs_min, rhs->region[i]->min)) { if (assert_mode_) { std::ostringstream os; diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 5ddc4526e854..0f56c8b8b7c9 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -53,8 +53,7 @@ class IntermediateStageRewriter { const BufferStoreNode* store = block->body.as(); CHECK(store != nullptr && runtime::StorageScope::Create(store->buffer.scope()).rank == runtime::StorageRank::kShared) - << "ValueError: Expect the body of the block to be BufferStore to shared memory." 
- << "But get " << block->body; + << "ValueError: Expect the body of the block to be BufferStore to shared memory."; const Buffer& target_buffer = store->buffer; From 86d4002a1197f69b2688c957bdc28201562f1616 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 23 Feb 2023 15:26:29 -0800 Subject: [PATCH 06/15] lint --- src/meta_schedule/schedule_rule/multi_level_tiling.h | 6 ++++-- .../schedule_rule/multi_level_tiling_tensor_core.cc | 4 +++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 1290491a1f75..dc679b52c50e 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -164,8 +164,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode { protected: virtual std::vector ApplySubRules(std::vector states); - virtual std::pair, Array> SplitLoop(const tir::Schedule& sch, tir::BlockRV block, - tir::LoopRV loop, int n_tiles) const; + virtual std::pair, Array> SplitLoop(const tir::Schedule& sch, + tir::BlockRV block, + tir::LoopRV loop, + int n_tiles) const; // Annotate a block to use cooperative fetching void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 27554f5985d8..edf538fbfd81 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -425,7 +425,9 @@ std::vector MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore( sch->ComputeAt(cache_write, i1, true); } { - const auto& [i0, j0, i1, j1] = f_get_loops(cache_write); + auto loops = f_get_loops(cache_write); + const auto& i0 = loops[0]; + const auto& j0 = loops[1]; auto fused = sch->Fuse({i0, j0}); sch->Bind(fused, "threadIdx.y"); } From 8f4f32b73eafe6c12d064ca6490132461bc1f138 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 23 Feb 2023 16:04:04 -0800 Subject: [PATCH 07/15] fix rewrite_tensorize test --- src/tir/schedule/ir_comparator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 045f45aa6634..5353a051a60a 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -436,7 +436,7 @@ bool TensorizeComparator::CompareBufferRegion(const BufferRegion& lhs, const Buf } return false; } - indices_base.emplace_back(lhs_analyzer_.Simplify(lhs->region[i]->min)); + indices_base.emplace_back(lhs->region[i]->min); } for (size_t i = 0; i < rhs->region.size(); i++) { // save base index From f642ff2664729c0ddb22c471db81630a969ff3d0 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 23 Feb 2023 16:12:17 -0800 Subject: [PATCH 08/15] fix software pipeline test --- ...est_tir_transform_inject_software_pipeline.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 1e5fd8843ba3..b9f35ed553e1 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -1124,7 +1124,7 @@ def ref(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")): B = T.alloc_buffer([2, 16, 1], 
dtype="float32", scope="shared") with T.block(): T.reads(A[tx, 0]) - T.writes(B[0, tx, 0]) + T.writes(B[T.FloorMod(0, 2), tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[T.FloorMod(0, 2), tx, 0] = A[tx, 0] * T.float32(2) @@ -1350,8 +1350,8 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N B[i % 2, tx, 0] = A[tx, i] * T.float32(2) with T.block(): T.where(i == 1 and i - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[(i - 1) % 2, tx, 0]) + T.writes(C[(i - 1) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 1): @@ -1366,14 +1366,14 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N with T.block(): T.where(i + 2 < 16) T.reads(A[tx, i + 2]) - T.writes(B[i % 2, tx, 0]) + T.writes(B[(i + 2) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 0): with T.attr(0, "async_scope", 1): B[(i + 2) % 2, tx, 0] = A[tx, i + 2] * T.float32(2) with T.block(): T.where(i + 2 - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[(i - 1 + 2) % 2, tx, 0]) + T.writes(C[(i - 1 + 2) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 1): @@ -1394,8 +1394,8 @@ def ref(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")) -> N for i in T.unroll(2): with T.block(): T.where(i + 16 - 1 < 16) - T.reads(B[(i + 1) % 2, tx, 0]) - T.writes(C[(i + 1) % 2, tx, 0]) + T.reads(B[(i - 1 + 16) % 2, tx, 0]) + T.writes(C[(i - 1 + 16) % 2, tx, 0]) with T.attr(0, "async_commit_queue_scope", 1): with T.attr(0, "async_wait_queue_scope", 0): with T.attr(0, "async_wait_inflight_count", 0 - i): From 088e514db37829d00b74580fe4703c2406086eba Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 23 Feb 2023 19:40:34 -0800 Subject: [PATCH 09/15] fix compile on mac --- .../multi_level_tiling_tensor_core.cc | 49 ++++++++++--------- 1 file changed, 27 insertions(+), 22 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index edf538fbfd81..8d6c242a9bcd 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -314,28 +314,33 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa ICHECK_GE(buffer_ndim, 2); int num_higher_dims = buffer_ndim - 2; - sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, - tir::IndexMap::FromFunc(buffer_ndim, [&](const Array& indices) { - Array result; - result.reserve(indices.size() + 4); - for (int i = 0; i < num_higher_dims; ++i) { - result.push_back(indices[i]); - } - const auto& m = indices[num_higher_dims]; - const auto& n = indices[num_higher_dims + 1]; - auto accum_m = floormod(m, frag_shape_m); - auto accum_n = floormod(n, frag_shape_n); - auto outer_m = floordiv(m, frag_shape_m); - auto outer_n = floordiv(n, frag_shape_n); - - result.push_back(floordiv(outer_m, warp_num_frag_m)); - result.push_back(floordiv(outer_n, warp_num_frag_n)); - result.push_back(floormod(outer_m, warp_num_frag_m)); - result.push_back(floormod(outer_n, warp_num_frag_n)); - result.push_back(accum_m); - result.push_back(accum_n); - return result; - })); + auto index_map = + tir::IndexMap::FromFunc(buffer_ndim, + 
// frag_shape_m and frag_shape_n are structural bindings that cannot + // not be automatically captured until c++20 + [&, frag_shape_m = frag_shape_m, + frag_shape_n = frag_shape_n](const Array& indices) { + Array result; + result.reserve(indices.size() + 4); + for (int i = 0; i < num_higher_dims; ++i) { + result.push_back(indices[i]); + } + const auto& m = indices[num_higher_dims]; + const auto& n = indices[num_higher_dims + 1]; + auto accum_m = floormod(m, frag_shape_m); + auto accum_n = floormod(n, frag_shape_n); + auto outer_m = floordiv(m, frag_shape_m); + auto outer_n = floordiv(n, frag_shape_n); + + result.push_back(floordiv(outer_m, warp_num_frag_m)); + result.push_back(floordiv(outer_n, warp_num_frag_n)); + result.push_back(floormod(outer_m, warp_num_frag_m)); + result.push_back(floormod(outer_n, warp_num_frag_n)); + result.push_back(accum_m); + result.push_back(accum_n); + return result; + }); + sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map); // // Tile by the fragment shape // sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, From ad27f8210c40275df56766480f71aa446b7cdfa3 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 3 Mar 2023 14:59:11 -0800 Subject: [PATCH 10/15] fix test cases --- ...test_meta_schedule_schedule_rule_mlt_tc.py | 783 ++++++++---------- 1 file changed, 351 insertions(+), 432 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py index 9b869b4436c0..1cab2554e88f 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -83,39 +83,39 @@ def test_matmul_relu(shared_scope): @T.prim_func def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope=shared_scope) - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope=shared_scope) + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): - for ax0_ax1_fused in T.serial(4096): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4096): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) v1 
= T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(4096): + for ax0_ax1_fused in range(4096): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(4): + for ax2_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 2): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): @@ -127,8 +127,8 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): @@ -141,44 +141,54 @@ def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "f v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3) v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 
2, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(1024): - with T.block("C_reindex_shared"): - v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) - T.reads(C_reindex_shared[v0, v1]) - T.writes(compute[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) - compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 
+ v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(1, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on decision_0 = [ @@ -223,44 +233,42 @@ def test_matmul_relu_with_fallback(): # fmt: off @T.prim_func def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: - # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + # with T.block("root"): + C_reindex_shared = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 2, 2, 4, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): 
- for ax2_0_0 in T.serial(2): - for ax0_ax1_fused in T.serial(2048): + for ax2_0_0 in range(2): + for ax0_ax1_fused in range(2048): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 64) v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":4}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 4}) A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(8192): + for ax0_ax1_fused in range(8192): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(1): + for ax2_0_1 in range(1): for ax0_0, ax1_0 in T.grid(2, 4): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -271,9 +279,9 @@ def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -285,44 +293,54 @@ def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0_3 * 4 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 4 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 4): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(4096): - with T.block("C_reindex_shared"): - v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 128) - v1 = T.axis.spatial(128, ax0_ax1_fused % 128) - T.reads(C_reindex_shared[v0, v1]) - T.writes(compute[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) - compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + 
v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 4, v0_o % 2, v1_o % 4, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 4): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused) + v1 = T.axis.spatial(2, ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(4, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(2048): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused * 2 + ax0_0_1_ax1_0_1_fused) + v1 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 1024) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(4, ax0_ax1_ax3_ax4_ax5_fused % 1024 // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 64] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) # fmt: on decision_0 = [ ("SamplePerfectTile", [2, 2, 1, 1, 2]), @@ -373,46 +391,46 @@ def test_conv2d(shared_scope): @T.prim_func def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, 3, 32, 32), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 32), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16") - conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope=shared_scope) - conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 32], dtype="float32", scope="wmma.accumulator") - PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope=shared_scope) - weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope=shared_scope) - PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], dtype="float16", scope="wmma.matrix_a") - weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], dtype="float16", scope="wmma.matrix_b") + PadInput = T.alloc_buffer((1, 18, 18, 32), "float16") + conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope=shared_scope) + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 2, 1, 1, 16, 16), scope="wmma.accumulator") + 
PadInput_reindex_shared = T.alloc_buffer((256, 288), "float16", scope=shared_scope) + weight_reindex_shared = T.alloc_buffer((288, 32), "float16", scope=shared_scope) + PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 288), "float16", scope="wmma.matrix_a") + weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((288, 32), "float16", scope="wmma.matrix_b") for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): with T.block("PadInput"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) - T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16") + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.float16(0)) for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): - for ax2_0_0 in T.serial(1): - for ax0_ax1_fused in T.serial(4608): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4608): with T.block("PadInput_reindex_shared"): v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 288) v1 = T.axis.spatial(288, ax0_ax1_fused % 288) T.reads(PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]) T.writes(PadInput_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 2}) PadInput_reindex_shared[v0, v1] = PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] - for ax0_ax1_fused in T.serial(4608): + for ax0_ax1_fused in range(4608): with T.block("weight_reindex_shared"): v0 = T.axis.spatial(288, ax0_ax1_fused // 16) v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16) T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]) T.writes(weight_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] - for ax2_0_1 in T.serial(18): + for ax2_0_1 in range(18): for ax0_0, ax1_0 in T.grid(1, 1): with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) v1_o = T.axis.spatial(18, ax2_0_1 + ax1_0) - T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("PadInput_reindex_shared_wmma.matrix_a"): @@ -424,8 +442,8 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, with 
T.block("weight_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(18, ax2_0_1 + ax0_0) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0) - T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(weight_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("weight_reindex_shared_wmma.matrix_b"): @@ -438,44 +456,49 @@ def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, v0_o = T.axis.spatial(16, ax0_0_4 + ax0_0_1_ax1_0_1_fused + ax0_0_3) v1_o = T.axis.spatial(2, ax0_0_0_ax1_0_0_fused + ax1_0_3 + ax1_0_4) v2_o = T.axis.reduce(18, ax2_0_0 * 18 + ax2_0_1 + ax2_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("conv2d_nhwc_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("conv2d_nhwc"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(1, 1): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(16, ax0_0_1_ax1_0_1_fused + ax0_0) - v1_o = T.axis.spatial(2, 
ax0_0_0_ax1_0_0_fused + ax1_0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(256): - with T.block("conv2d_nhwc_reindex_shared"): - v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 16) - v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16) - T.reads(conv2d_nhwc_reindex_shared[v0, v1]) - T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) - conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i], PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o, v1_o, 0, 0, v0_i, v1_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(1): + for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2_1, ax3]) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(256): + with T.block("conv2d_nhwc_reindex_shared"): + v0, v1, v2 = T.axis.remap("SSS", [ax0_0_1_ax1_0_1_fused, ax0_0_0_ax1_0_0_fused, ax2]) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 
3}) + conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 16, 1, 1, 1]), @@ -551,40 +574,40 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body # with T.block("root") - C = T.alloc_buffer([128, 128], dtype="float32") - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope=shared_scope) - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope=shared_scope) - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + C = T.alloc_buffer((128, 128)) + C_reindex_shared = T.alloc_buffer((4, 4, 2, 2, 16, 16), scope=shared_scope) + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 4, 2, 2, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope=shared_scope) + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): - for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order":[0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage":[0, 0, 0, 0, 0, 1, 1]}): - for ax0_ax1_fused in T.serial(1024): + for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): + for ax0_ax1_fused in range(1024): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused % 32) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":4, "tir.manifest_shared_memory_local_stage":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 4, "tir.manifest_shared_memory_local_stage": 1}) A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(1024): + for ax0_ax1_fused in range(1024): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":2, "tir.manifest_shared_memory_local_stage":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 2, "tir.manifest_shared_memory_local_stage": 1}) B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(2, annotations={"software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}): + for ax2_0_1 in 
T.serial(2, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0, ax1_0 in T.grid(2, 1): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_a_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): @@ -596,8 +619,8 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) T.block_attr({"meta_schedule.auto_tensorize": f"wmma_load_16x16x16_f16_b_{intrin_suffix}"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): @@ -610,50 +633,61 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0_3 * 2 + ax1_0_4) v2_o = T.axis.reduce(8, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - 
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.grid(1024): - with T.block("C_reindex_shared"): - v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32) - T.reads(C_reindex_shared[v0, v1]) - T.writes(C[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":3}) - C[v0, v1] = C_reindex_shared[v0, v1] + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o // 2, v0_o % 2, v1_o % 2, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 2): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4) + v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(2, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": f"wmma_store_16x16x16_f32_{intrin_suffix}"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, 
v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused // 4) + v1 = T.axis.spatial(4, ax0_0_1_ax1_0_1_fused % 4) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(2, ax0_ax1_ax3_ax4_ax5_fused // 256) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32]) + T.block_attr({"meta_schedule.cooperative_fetch": 3}) + C[v4 + v2 * 16 + v0 * 32, v5 + v3 * 16 + v1 * 32] = C_reindex_shared[v0, v1, v2, v3, v4, v5] for i0, i1 in T.grid(128, 128): with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(C[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(C[v_i0, v_i1]) + T.writes(compute[v_i0, v_i1]) + compute[v_i0, v_i1] = T.max(C[v_i0, v_i1], T.float32(0)) + # fmt: on decision_0 = [ ("SamplePerfectTile", [1, 4, 1, 1, 2]), @@ -693,141 +727,6 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, ) -def test_matmul_relu_global(): - # fmt: off - @T.prim_func - def matmul_relu_global_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - C = T.alloc_buffer([128, 128], dtype="float32") - C_reindex_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") - for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): - for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): - for ax0_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): - for ax2_0_0 in T.serial(2): - for ax0_ax1_fused in T.serial(8192): - with T.block("A_reindex_shared"): - v0 = T.axis.spatial(128, ax0_ax1_fused // 64) - v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64) - T.reads(A[v0, v1]) - T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) - A_reindex_shared[v0, v1] = A[v0, v1] - for ax0_ax1_fused in T.serial(8192): - with T.block("B_reindex_shared"): - v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128) - v1 = T.axis.spatial(128, ax0_ax1_fused % 128) - T.reads(B[v0, v1]) - T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) - B_reindex_shared[v0, v1] = B[v0, v1] - for ax2_0_1 in T.serial(2): - for ax0_0, ax1_0 in T.grid(1, 2): - with T.block("A_reindex_shared_wmma.matrix_a_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0) - v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : 
v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("A_reindex_shared_wmma.matrix_a"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_0, ax1_0 in T.grid(2, 4): - with T.block("B_reindex_shared_wmma.matrix_b_o"): - v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("B_reindex_shared_wmma.matrix_b"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 4, 2, 1, 1): - with T.block("C_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0_3 + ax0_0_4) - v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0_3) - v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) - with T.init(): - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_init"): - v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads() - T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) - for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): - with T.block("C"): - v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(1, 4): - with T.block("C_reindex_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0) - v1_o = T.axis.spatial(8, 
ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0) - T.reads(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_global"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for i0, i1 in T.grid(128, 128): - with T.block("compute"): - i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) - T.reads(C[i0_1, i1_1]) - T.writes(compute[i0_1, i1_1]) - compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) - # fmt: on - decision_0 = [ - ("SamplePerfectTile", [1, 1, 8, 1, 1]), - ("SamplePerfectTile", [1, 1, 2, 4, 1]), - ("SamplePerfectTile", [2, 2, 2]), - ("SampleCategorical", 0), - ("SampleCategorical", 0), - ] - mod = te.create_prim_func( - te_workload.matmul_relu( - n=128, - m=128, - k=128, - in_dtype="float16", - out_dtype="float32", - ) - ) - actual = generate_design_space( - kind="cuda", - mod=mod, - target=tvm.target.Target("cuda"), - types=None, - sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")] - + get_rules("cuda", ms.schedule_rule.AutoInline), - ) - check_sketches( - mod, - sketches=actual, - expected_mods=[matmul_relu_global_0], - expected_decisions=[decision_0], - ) - - def test_matmul_relu_non_tensorizable(): # expected to do nothing on non-tensorizable workloads mod = te.create_prim_func( @@ -842,7 +741,7 @@ def test_matmul_relu_non_tensorizable(): mod=mod, target=tvm.target.Target("cuda"), types=None, - sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")] + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")] + get_rules("cuda", ms.schedule_rule.AutoInline), ) tvm.ir.assert_structural_equal(mod, sch.mod["main"]) @@ -856,40 +755,40 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body # with T.block("root") - C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") - C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") - A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") - A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") - B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + C_reindex_shared = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + B_reindex_shared = T.alloc_buffer((128, 128), "float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer((128, 128), "float16", scope="wmma.matrix_b") for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): - for 
ax2_0_0 in T.serial(1): - for ax0_ax1_fused in T.serial(4096): + for ax2_0_0 in range(1): + for ax0_ax1_fused in range(4096): with T.block("A_reindex_shared"): v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) v1 = T.axis.spatial(128, ax0_ax1_fused % 128) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) - A_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, A[v0, v1], T.float16(0), dtype="float16") - for ax0_ax1_fused in T.serial(4096): + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 8}) + A_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, A[v0, v1], T.float16(0)) + for ax0_ax1_fused in range(4096): with T.block("B_reindex_shared"): v0 = T.axis.spatial(128, ax0_ax1_fused // 32) v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) - B_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, B[v0, v1], T.float16(0), dtype="float16") - for ax2_0_1 in T.serial(4): + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) + B_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, B[v0, v1], T.float16(0)) + for ax2_0_1 in range(4): for ax0_0, ax1_0 in T.grid(2, 2): with T.block("A_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) - T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) + T.reads(A_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("A_reindex_shared_wmma.matrix_a"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -900,9 +799,9 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 with T.block("B_reindex_shared_wmma.matrix_b_o"): v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) + T.reads(B_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax0_1, ax1_1 in T.grid(16, 16): with T.block("B_reindex_shared_wmma.matrix_b"): v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) @@ -914,45 +813,56 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1 v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + 
ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3) v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) - T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16:v2_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax0_1, ax1_1 in T.grid(16, 16): with T.block("C_init"): v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) T.reads() - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i_init, v1_i_init] = T.float32(0) for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): with T.block("C"): v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") - for ax0_0, ax1_0 in T.grid(2, 1): - with T.block("C_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) - v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("C_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(1024): - with T.block("C_reindex_shared"): - T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32 < 127) - v0 = 
T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) - T.reads(C_reindex_shared[v0, v1]) - T.writes(compute[v0, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":4}) - compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + T.reads(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] = C_reindex_shared_wmma_accumulator[v0_o // 2, v1_o, v0_o % 2, 0, v0_i, v1_i] + T.Cast("float32", A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i]) * T.Cast("float32", B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + for ax2 in range(2): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_fused) + v2 = T.axis.spatial(2, ax2 + ax2_1) + v3 = T.axis.spatial(1, ax3) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(C_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + C_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = C_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(4, ax0_0_0_ax1_0_0_fused // 2 + 0) + v1 = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256) + v2 = T.axis.spatial(2, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused % 512 // 256 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 16 < 127) + T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0)) + # fmt: on decision_0 = [ @@ -994,25 +904,25 @@ def test_conv_1x1(): @T.prim_func def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")) -> None: T.func_attr({"global_symbol": "main", "tir.noalias": True}) - conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 64], dtype="float32", scope="shared") - conv2d_nhwc_reindex_shared_wmma_accumulator = 
T.alloc_buffer([256, 64], dtype="float32", scope="wmma.accumulator") - PadInput_reindex_shared = T.alloc_buffer([256, 64], dtype="float16", scope="shared") - weight_reindex_shared = T.alloc_buffer([1, 1, 64, 64], dtype="float16", scope="shared") - PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 64], dtype="float16", scope="wmma.matrix_a") - weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([1, 1, 64, 64], dtype="float16", scope="wmma.matrix_b") + conv2d_nhwc_reindex_shared = T.alloc_buffer((16, 4, 1, 1, 16, 16), scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer((16, 4, 1, 1, 16, 16), scope="wmma.accumulator") + PadInput_reindex_shared = T.alloc_buffer((256, 64), "float16", scope="shared") + weight_reindex_shared = T.alloc_buffer((1, 1, 64, 64), "float16", scope="shared") + PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer((256, 64), "float16", scope="wmma.matrix_a") + weight_reindex_shared_wmma_matrix_b = T.alloc_buffer((1, 1, 64, 64), "float16", scope="wmma.matrix_b") for ax2_0_0_ax3_0_0_fused in T.thread_binding(16, thread="blockIdx.y"): for ax2_0_1_ax3_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): for ax2_0_2_ax3_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): for ax0_0, ax1_0, ax4_0_0 in T.grid(1, 1, 1): - for ax0_ax1_fused in T.serial(1024): + for ax0_ax1_fused in range(1024): with T.block("PadInput_reindex_shared"): v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64) v1 = T.axis.spatial(64, ax0_ax1_fused % 64) T.reads(inputs[v0 // 256, v0 // 16, v0 % 16, v1]) T.writes(PadInput_reindex_shared[v0, v1]) - T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + T.block_attr({"buffer_dim_align": [[0, 0, 32, 8]], "meta_schedule.cooperative_fetch": 1}) PadInput_reindex_shared[v0, v1] = inputs[v0 // 256, v0 // 16, v0 % 16, v1] - for ax0_ax1_ax2_ax3_fused in T.serial(2048): + for ax0_ax1_ax2_ax3_fused in range(2048): with T.block("weight_reindex_shared"): v0 = T.axis.spatial(1, 0) v1 = T.axis.spatial(1, 0) @@ -1020,16 +930,16 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( v3 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_ax2_ax3_fused % 32) T.reads(weight[v0, v1, v2, v3]) T.writes(weight_reindex_shared[v0, v1, v2, v3]) - T.block_attr({"buffer_dim_align":[[0, 2, 32, 8]], "meta_schedule.cooperative_fetch":4}) + T.block_attr({"buffer_dim_align": [[0, 2, 32, 8]], "meta_schedule.cooperative_fetch": 4}) weight_reindex_shared[v0, v1, v2, v3] = weight[v0, v1, v2, v3] for ax0_1, ax1_1, ax4_0_1 in T.grid(1, 1, 1): for ax0_0_1, ax1_0_1 in T.grid(1, 4): with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0_1) v1_o = T.axis.spatial(4, ax1_0_1) - T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a_shared"}) + T.reads(PadInput_reindex_shared[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16:v0_o * 16 + 16, v1_o * 16:v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_a_shared"}) for ax0_1_1, ax1_1_1 in T.grid(16, 16): with T.block("PadInput_reindex_shared_wmma.matrix_a"): v0_i, v1_i = 
T.axis.remap("SS", [ax0_1_1, ax1_1_1]) @@ -1040,9 +950,9 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( with T.block("weight_reindex_shared_wmma.matrix_b_o"): v0, v1, v2_o = T.axis.remap("SSS", [ax0, ax1, ax2_0]) v3_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0) - T.reads(weight_reindex_shared[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b_shared"}) + T.reads(weight_reindex_shared[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0, v1, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_load_16x16x16_f16_b_shared"}) for ax2_1, ax3_1 in T.grid(16, 16): with T.block("weight_reindex_shared_wmma.matrix_b"): v2_i, v3_i = T.axis.remap("SS", [ax2_1, ax3_1]) @@ -1056,44 +966,53 @@ def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer( v2_o = T.axis.spatial(16, ax2_0_4 + ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax2_0_3) v3_o = T.axis.spatial(4, ax3_0_4 + ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax3_0_3) v4_o = T.axis.reduce(4, ax4_0_0 * 4 + ax4_0_1 * 4 + ax4_0_2) - T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 : v2_o * 16 + 16, v4_o * 16 : v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 : v4_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 : v2_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16:v2_o * 16 + 16, v4_o * 16:v4_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16:v4_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init": "wmma_fill_16x16x16_f32", "warp_execution": 1}) with T.init(): for ax2_1, ax3_1 in T.grid(16, 16): with T.block("conv2d_nhwc_init"): v2_i_init, v3_i_init = T.axis.remap("SS", [ax2_1, ax3_1]) T.reads() - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init]) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i_init, v3_o * 16 + v3_i_init] = T.float32(0) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i_init, v3_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i_init, v3_i_init] = T.float32(0) for ax2_1, ax3_1, ax4_1 in T.grid(16, 16, 16): with T.block("conv2d_nhwc"): v2_i, v3_i, v4_i = T.axis.remap("SSR", [ax2_1, ax3_1, ax4_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) - T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i]) - T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) - conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] = 
conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o * 16 + v2_i, v3_o * 16 + v3_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i], "float32") - for ax0_0, ax1_0 in T.grid(1, 1): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): - v0_o = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused + ax0_0) - v1_o = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax2_0_2_ax3_0_2_fused + ax1_0) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) - T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) - for ax0_1, ax1_1 in T.grid(16, 16): - with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): - v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) - T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) - conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] - for ax0_ax1_fused in T.serial(512): - with T.block("conv2d_nhwc_reindex_shared"): - v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 32) - v1 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_fused % 32) - T.reads(conv2d_nhwc_reindex_shared[v0, v1]) - T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1]) - T.block_attr({"meta_schedule.cooperative_fetch":2}) - conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i], PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i], weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v2_o, v3_o, 0, 0, v2_i, v3_i] + T.Cast("float32", PadInput_reindex_shared_wmma_matrix_a[v2_o * 16 + v2_i, v4_o * 16 + v4_i]) * T.Cast("float32", weight_reindex_shared_wmma_matrix_b[v0, v1, v4_o * 16 + v4_i, v3_o * 16 + v3_i]) + for ax2 in range(1): + for ax0_ax1_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) + v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax0_ax1_fused) + v2, v3 = T.axis.remap("SS", [ax2_1, ax3]) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, 0:16, 0:16]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, 0:16, 0:16]) + T.block_attr({"meta_schedule.auto_tensorize": "wmma_store_16x16x16_f32_shared"}) + for ax4, ax5 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v4_i, v5_i = T.axis.remap("SS", [ax4, ax5]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i]) + T.writes(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i]) + 
conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4_i, v5_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0, v1, v2, v3, v4_i, v5_i] + for ax0_ax1_ax3_ax4_ax5_fused in range(512): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(16, ax2_0_0_ax3_0_0_fused // 2 * 2 + ax2_0_1_ax3_0_1_fused) + v1 = T.axis.spatial(4, ax2_0_0_ax3_0_0_fused % 2 * 2 + ax0_ax1_ax3_ax4_ax5_fused // 256) + v2 = T.axis.spatial(1, ax2) + v3 = T.axis.spatial(1, 0) + v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16) + v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16) + T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]) + T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16]) + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5] # fmt: on decision_0 = [ From 27e2db925c0ada4c0317293cfe4c9f5c02f61de3 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 3 Mar 2023 14:59:54 -0800 Subject: [PATCH 11/15] remove unused --- .../multi_level_tiling_tensor_core.cc | 38 ------------------- 1 file changed, 38 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 8d6c242a9bcd..054ca54515ef 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -342,44 +342,6 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa }); sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map); - // // Tile by the fragment shape - // sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, - // tir::IndexMap::FromFunc(buffer_ndim, [&](const Array& indices) { - // Array result; - // result.reserve(indices.size() + 2); - // for (int i = 0; i < num_higher_dims; ++i) { - // result.push_back(indices[i]); - // } - // const auto& m = indices[num_higher_dims]; - // const auto& n = indices[num_higher_dims + 1]; - // result.push_back(floordiv(m, frag_shape_m)); - // result.push_back(floordiv(n, frag_shape_n)); - // result.push_back(floormod(m, frag_shape_m)); - // result.push_back(floormod(n, frag_shape_n)); - // return result; - // })); - - // // Tile by the number of fragments - // sch->TransformLayout( - // state->block_rv, 0, tir::BufferIndexType::kWrite, - // tir::IndexMap::FromFunc(buffer_ndim + 2, [&](const Array& indices) { - // Array result; - // result.reserve(indices.size() + 2); - // for (int i = 0; i < num_higher_dims; ++i) { - // result.push_back(indices[i]); - // } - // const auto& m = indices[num_higher_dims]; - // const auto& n = indices[num_higher_dims + 1]; - // result.push_back(floordiv(m, warp_num_frag_m)); - // result.push_back(floordiv(n, warp_num_frag_n)); - // result.push_back(floormod(m, warp_num_frag_m)); - // result.push_back(floormod(n, warp_num_frag_n)); - // // The last two indices are the fragment element indices - // result.push_back(indices[num_higher_dims + 2]); - // result.push_back(indices[num_higher_dims + 3]); - // return result; - // })); - return {state}; } From abac67c851c4acec16b8c8b7362805eeddaae55b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 3 Mar 2023 15:02:18 -0800 Subject: [PATCH 12/15] rebase --- .../schedule_rule/multi_level_tiling_tensor_core.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff 
--git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 054ca54515ef..1f9945022b66 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -340,7 +340,8 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa result.push_back(accum_n); return result; }); - sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map); + sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map, + /*pad_value=*/NullOpt, /*assume_injective_transform=*/true); return {state}; } @@ -651,7 +652,8 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( state->sch->state(), GetRef(block), buffer_index, index_type); auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); buffer_sub_index_map.Set(lhs_buffer, sub_index_map); - state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt); + state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, + /*pad_value=*/NullOpt, /*assume_injective_transform=*/true); }; for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) { From 67674759345120e36f2c0d680bc66b1e067c7295 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 3 Mar 2023 15:19:49 -0800 Subject: [PATCH 13/15] only use json format for roundtrip --- python/tvm/meta_schedule/testing/space_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/testing/space_generation.py b/python/tvm/meta_schedule/testing/space_generation.py index 0b7072b65afe..45cd6659b6e0 100644 --- a/python/tvm/meta_schedule/testing/space_generation.py +++ b/python/tvm/meta_schedule/testing/space_generation.py @@ -88,7 +88,7 @@ def _find_match_sketch_id( decisions=new_decisions, ).apply_to_schedule(sch, remove_postproc=True) if structural_equal(sch.mod, expected_mod): - verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask) + verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask, text_format="json") return sketch_id return None From aefb25cb7f47df699376664bcc56c665b54603ad Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 3 Mar 2023 16:48:03 -0800 Subject: [PATCH 14/15] lint --- src/meta_schedule/schedule_rule/multi_level_tiling.cc | 1 - src/meta_schedule/schedule_rule/multi_level_tiling.h | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 8b7d613563f1..0312c100b51b 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -240,7 +240,6 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { for (int j = 0; j < n_tiles; ++j) { tiles[idx->at(j)].push_back(splits[j]); tile_factors[idx->at(j)].push_back(factors[j]); - // Array& a=state->tile_size[idx->at(j)];//.push_back(factors[j]); } } } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index dc679b52c50e..41b3ca9f26f3 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -94,11 +94,12 @@ class StateNode : public Object { tir::BlockRV block_rv; /*! \brief The loop tiles */ Array> tiles; + /*! 
\brief The factors of the loop tiles. */
+  Array<Array<tir::ExprRV>> tile_factors;
   /*! \brief The mapping from buffer index to read cache block. */
   std::unordered_map<int, tir::BlockRV> read_reuse;
   /*! \brief The mapping from buffer index to write cache block. */
   std::unordered_map<int, tir::BlockRV> write_reuse;
-  Array<Array<tir::ExprRV>> tile_factors;
 
   /*!
    * \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that

From cbad45b8727cde013698ae548ff353da84b0ad53 Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Sat, 4 Mar 2023 08:53:38 -0500
Subject: [PATCH 15/15] Update src/tir/schedule/ir_comparator.h

Co-authored-by: Siyuan Feng
---
 src/tir/schedule/ir_comparator.h | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h
index b37fb6654670..debf0f946e28 100644
--- a/src/tir/schedule/ir_comparator.h
+++ b/src/tir/schedule/ir_comparator.h
@@ -104,7 +104,8 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
   bool is_scope_block = true;
   /*! \brief The arithmetic analyzer for comparing LHS and RHS */
   arith::Analyzer analyzer_;
-  /*! \brief The arithmetic analyzer for simplifying expressions on LHS.
+  /*!
+   * \brief The arithmetic analyzer for simplifying expressions on LHS.
    * This analyzer only contains the domains of the iterators on LHS.
    */
   arith::Analyzer lhs_analyzer_;
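
For reference, a minimal Python sketch of the packed intermediate-output layout this series introduces, written against the public IndexMap API. The wmma fragment shape (16, 16) and the per-warp fragment counts below are illustrative assumptions for this example, not values read from the schedule rule:

# Illustrative sketch only: approximates the tile-and-pack index map applied to
# the accumulator. frag_m/frag_n and the per-warp fragment counts are assumed.
from tvm import tir

frag_m, frag_n = 16, 16          # assumed wmma fragment shape
warp_frag_m, warp_frag_n = 1, 1  # assumed fragments per warp along each axis

index_map = tir.IndexMap.from_func(
    lambda m, n: [
        m // (warp_frag_m * frag_m),           # fragment-block index along m
        n // (warp_frag_n * frag_n),           # fragment-block index along n
        m % (warp_frag_m * frag_m) // frag_m,  # fragment index within the warp tile (m)
        n % (warp_frag_n * frag_n) // frag_n,  # fragment index within the warp tile (n)
        m % frag_m,                            # element row inside a fragment
        n % frag_n,                            # element column inside a fragment
    ]
)

# With these assumed factors, a 256x64 accumulator packs into the
# (16, 4, 1, 1, 16, 16) shared-memory buffer shape that appears in the
# expected TVMScript above.
print(index_map.map_shape([256, 64]))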