[MetaSchedule] Tile and pack intermediate output for CUDA TensorCore #14108

Merged · 15 commits · Mar 6, 2023
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/testing/space_generation.py
@@ -88,7 +88,7 @@ def _find_match_sketch_id(
decisions=new_decisions,
).apply_to_schedule(sch, remove_postproc=True)
if structural_equal(sch.mod, expected_mod):
verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask)
verify_trace_roundtrip(sch=sch, mod=mod, debug_mask=debug_mask, text_format="json")
return sketch_id
return None

1 change: 1 addition & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
@@ -162,6 +162,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
13 changes: 9 additions & 4 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -186,15 +186,15 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
return results;
}

Array<tir::LoopRV> MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop,
int n_tiles) const {
std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingNode::SplitLoop(
const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const {
Array<tir::ExprRV> factors = sch->SamplePerfectTile(
/*loop=*/loop,
/*n=*/n_tiles,
/*max_innermost_factor=*/max_innermost_factor);
Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop,
/*factors=*/{factors.begin(), factors.end()});
return splits;
return {factors, splits};
}

std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
@@ -207,6 +207,9 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
// Step 2. For each loop axis, tile it
int64_t spatial_loop_product = 1;
std::vector<Array<LoopRV>> tiles(s_indices_.size() + r_indices_.size());
state->tile_factors.resize(tiles.size());
std::vector<Array<tir::ExprRV>> tile_factors;
tile_factors.resize(tiles.size());
for (int i = 0, n = loops.size(); i < n; ++i) {
LoopRV loop = loops[i];
const std::vector<int>* idx = nullptr;
@@ -231,14 +234,16 @@ std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const {
if (n_tiles == 1) {
tiles[idx->at(0)].push_back(loop);
} else {
auto splits = SplitLoop(sch, block_rv, loop, n_tiles);
auto [factors, splits] = SplitLoop(sch, block_rv, loop, n_tiles);

// Put every tile to its slot
for (int j = 0; j < n_tiles; ++j) {
tiles[idx->at(j)].push_back(splits[j]);
tile_factors[idx->at(j)].push_back(factors[j]);
}
}
}
state->tile_factors = std::move(tile_factors);
// Step 3. Reorder to organize the tiles
sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end()));
// Step 4. Bind the tiles to threads
8 changes: 6 additions & 2 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -94,6 +94,8 @@ class StateNode : public Object {
tir::BlockRV block_rv;
/*! \brief The loop tiles */
Array<Array<tir::LoopRV>> tiles;
/*! \brief The factors of the loop tiles. */
Array<Array<tir::ExprRV>> tile_factors;
/*! \brief The mapping from buffer index to read cache block. */
std::unordered_map<int, tir::BlockRV> read_reuse;
/*! \brief The mapping from buffer index to write cache block. */
@@ -163,8 +165,10 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states);

virtual Array<tir::LoopRV> SplitLoop(const tir::Schedule& sch, tir::BlockRV block,
tir::LoopRV loop, int n_tiles) const;
virtual std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> SplitLoop(const tir::Schedule& sch,
tir::BlockRV block,
tir::LoopRV loop,
int n_tiles) const;

// Annotate a block to use cooperative fetching
void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;
176 changes: 163 additions & 13 deletions src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/tir/op.h>

#include <algorithm>
#include <utility>
@@ -124,6 +125,9 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
private:
// SubRule: Add tensorization-related transformations
inline std::vector<State> TransformForTensorization(TensorCoreState state) const;
// Subrule: Transform the layout of the output. This is necessary for an efficient cache write of
// the output to shared memory.
std::vector<State> TransformIntermediateOutputLayout(TensorCoreState state);
// Subrule: Add tensorized load
inline std::vector<State> AddReadReuseTensorCore(TensorCoreState state) const;
// Subrule: Add tensorized store
@@ -225,6 +229,9 @@ std::vector<State> MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<Sta
return TransformForTensorization(Downcast<TensorCoreState>(state));
});
states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); });
states = SubRule(std::move(states), [&](State state) {
return TransformIntermediateOutputLayout(Downcast<TensorCoreState>(state));
});
states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); });
states = SubRule(std::move(states), [&](State state) {
return AddWriteReuseTensorCore(Downcast<TensorCoreState>(state));
@@ -248,25 +255,162 @@ void MultiLevelTilingTensorCoreNode::TileAndAnnotateTensorize(Schedule* sch,
(*sch)->Annotate(blockized_outer, tir::attr::meta_schedule_auto_tensorize, intrin_name);
}

std::vector<State> MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLayout(
TensorCoreState state) {
// Transform the intermediate output to packed layout
// [..., warp_m, warp_n, accum_frag_m, accum_frag_n, accum_elem_m, accum_elem_n]
// where warp_m, warp_n are thread indices bound to the warp id, accum_frag_m, accum_frag_n are
// the indices of the fragments within each warp, and accum_elem_m, accum_elem_n are the indices
// of the elements within each accumulator fragment.

// Get the shape of the wmma accumulator
auto [frag_shape_m, frag_shape_n] = [&]() {
tir::Block intrin_block =
Downcast<tir::BlockRealize>(
tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body)
->block;
tir::For loop_m = Downcast<tir::For>(intrin_block->body);
tir::For loop_n = Downcast<tir::For>(loop_m->body);
return std::make_tuple(loop_m->extent, loop_n->extent);
}();

// Get the tile index of the warp id (i.e. threadIdx.y)
auto it = std::find(tile_binds.begin(), tile_binds.end(), "threadIdx.y");
ICHECK(it != tile_binds.end());
auto tile_index_warp_id = std::distance(tile_binds.begin(), it);

// Get the extent of the loop indicated by `loop_idx` inside the warp scope.
// For example, after spatial loops i, j are tiled, we will have
// tile_factors = ((i0, j0), (i1, j1), ..., (in, jn))
// This function computes the product of tile_factors[i][loop_idx] for i > tile_index_warp_id.
// `loop_idx` can be negative, in which case it is counted from the end.
auto f_get_inner_tile_product = [&](int loop_idx) {
Array<tir::ExprRV> factors;
for (int i = tile_index_warp_id + 1; i < static_cast<int>(s_indices_.size()); ++i) {
auto s_factors = state->tile_factors[s_indices_[i]];
if (loop_idx < 0) {
loop_idx += s_factors.size();
}
factors.push_back(s_factors[loop_idx]);
}
ICHECK(!factors.empty());
if (factors.size() == 1) {
return factors[0];
}
auto result = factors[0];
for (int i = 1; i < static_cast<int>(factors.size()); ++i) {
result = result * factors[i];
}
return result;
};

// Compute the number of output fragments owned by each warp
auto warp_num_frag_m = f_get_inner_tile_product(-2);
auto warp_num_frag_n = f_get_inner_tile_product(-1);
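
For illustration, a minimal standalone sketch of the product that f_get_inner_tile_product computes. The tile factor values, the warp-id tile level, and the helper name inner_tile_product are made-up assumptions for this example, not values produced by the rule:

#include <cstdio>
#include <vector>

int main() {
  // tile_factors[level] = {factor of loop i, factor of loop j} at each tiling level.
  std::vector<std::vector<int>> tile_factors = {{4, 4}, {2, 2}, {2, 4}, {1, 2}};
  size_t tile_index_warp_id = 1;  // the tiling level bound to threadIdx.y (the warp id)

  // Product of tile_factors[level][loop_idx] over all levels inside the warp scope.
  auto inner_tile_product = [&](int loop_idx) {
    int prod = 1;
    for (size_t level = tile_index_warp_id + 1; level < tile_factors.size(); ++level) {
      const std::vector<int>& factors = tile_factors[level];
      int idx = loop_idx < 0 ? loop_idx + static_cast<int>(factors.size()) : loop_idx;
      prod *= factors[idx];
    }
    return prod;
  };

  // Fragments each warp owns along m and n: 2 * 1 = 2 and 4 * 2 = 8.
  std::printf("warp_num_frag_m = %d, warp_num_frag_n = %d\n",
              inner_tile_product(-2), inner_tile_product(-1));
  return 0;
}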

Schedule& sch = state->sch;
int buffer_ndim = static_cast<int>(sch->Get(state->block_rv)->writes[0]->buffer->shape.size());
// The buffer should have at least as many dimensions as the tensor intrin.
ICHECK_GE(buffer_ndim, 2);
int num_higher_dims = buffer_ndim - 2;

auto index_map =
tir::IndexMap::FromFunc(buffer_ndim,
// frag_shape_m and frag_shape_n are structured bindings, which cannot
// be captured by lambdas before C++20
[&, frag_shape_m = frag_shape_m,
frag_shape_n = frag_shape_n](const Array<tir::Var>& indices) {
Array<PrimExpr> result;
result.reserve(indices.size() + 4);
for (int i = 0; i < num_higher_dims; ++i) {
result.push_back(indices[i]);
}
const auto& m = indices[num_higher_dims];
const auto& n = indices[num_higher_dims + 1];
auto accum_m = floormod(m, frag_shape_m);
auto accum_n = floormod(n, frag_shape_n);
auto outer_m = floordiv(m, frag_shape_m);
auto outer_n = floordiv(n, frag_shape_n);

result.push_back(floordiv(outer_m, warp_num_frag_m));
result.push_back(floordiv(outer_n, warp_num_frag_n));
result.push_back(floormod(outer_m, warp_num_frag_m));
result.push_back(floormod(outer_n, warp_num_frag_n));
result.push_back(accum_m);
result.push_back(accum_n);
return result;
});
sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, index_map,
/*pad_value=*/NullOpt, /*assume_injective_transform=*/true);

return {state};
}
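
As a concrete reference for the IndexMap above, the following standalone sketch reproduces the same index arithmetic in plain C++; the 16x16 fragment shape and the 2x4 fragments-per-warp counts are illustrative assumptions, not values fixed by this PR:

#include <array>
#include <cstdio>

int main() {
  const int frag_shape_m = 16, frag_shape_n = 16;      // wmma accumulator fragment shape
  const int warp_num_frag_m = 2, warp_num_frag_n = 4;  // fragments owned by each warp

  // Map a 2D output coordinate (m, n) to the packed layout
  // [warp_m, warp_n, accum_frag_m, accum_frag_n, accum_elem_m, accum_elem_n].
  auto pack = [&](int m, int n) {
    int accum_m = m % frag_shape_m, accum_n = n % frag_shape_n;
    int outer_m = m / frag_shape_m, outer_n = n / frag_shape_n;
    return std::array<int, 6>{outer_m / warp_num_frag_m, outer_n / warp_num_frag_n,
                              outer_m % warp_num_frag_m, outer_n % warp_num_frag_n,
                              accum_m, accum_n};
  };

  // Element (40, 100) lands in warp (1, 1), fragment (0, 2), element (8, 4).
  std::array<int, 6> idx = pack(40, 100);
  std::printf("[%d, %d, %d, %d, %d, %d]\n", idx[0], idx[1], idx[2], idx[3], idx[4], idx[5]);
  return 0;
}

Any higher (batch-like) dimensions would simply pass through unchanged, mirroring the num_higher_dims loop in the index map.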

std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
TensorCoreState state) const {
// Add the cache write stage for Tensor Core
int level = r_indices_.front() - 1;
const LoopRV& loop = state->tiles[level].back();
Schedule& sch = state->sch;
auto cache_write = sch->CacheWrite(state->block_rv, 0, "wmma.accumulator");
sch->ReverseComputeAt(cache_write, loop, true);

if (state->write_reuse.count(0)) {
// Fuse the iterators of the cache_write
Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
ICHECK_GT(buffer_loops.size(), 2);
sch->Fuse(Array<LoopRV>{buffer_loops.end() - 2, // The src shmem is always 2D
buffer_loops.end()});
AnnotateCooperativeFetching(&sch, state->write_reuse[0]);

// The compute block has been tiled by the warp shape and the fragment shape.
// We need to bind the cache write block (from the accumulator to the shared memory) to the warp
// id. The schedule is as follows:
//
// After adding cache write for wmma.accumulator, we will have
// for i0, j0, i1, j1, accum_m, accum_n:
// shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n]
// for i0', j0', i1', j1', accum_m', accum_n':
// global_mem[i0', j0', i1', j1', accum_m', accum_n'] =
// shared_mem[i0', j0', i1', j1', accum_m', accum_n']
// where i0' and j0' are already bound to the block id and warp id.
//
// To reduce the shared memory usage and allow efficient data movement, we will apply
// transformations to generate the following schedule:
//
// for i1':
// for i0_j0 (fused and bound to threadIdx.y):
// for j1, accum_m, accum_n:
// shared_mem[i0, j0, i1, j1, accum_m, accum_n] = accum[i0, j0, i1, j1, accum_m, accum_n]
// for i0', j0', j1', accum_m', accum_n':
// global_mem[i0', j0', i1', j1', accum_m', accum_n'] =
// shared_mem[i0', j0', i1', j1', accum_m', accum_n']
//
// i1' is reordered to be the outermost loop. This effectively allows only a row of fragments
// (one iteration of loop i1') to be moved to shared memory and then to global memory at a time.
// As a result, the shared memory for the output only needs shape [j1, accum_m, accum_n]
// instead of [i0 * i1 * accum_m, j0 * j1 * accum_n].

// Get the four loops (i0, j0, i1, j1) immediately preceding the innermost two loops (accum_m and accum_n).
auto f_get_loops = [&](const BlockRV& block_rv) -> std::array<LoopRV, 4> {
Array<LoopRV> buffer_loops = sch->GetLoops(block_rv);
ICHECK_GT(buffer_loops.size(), 6);
return {buffer_loops[buffer_loops.size() - 6], buffer_loops[buffer_loops.size() - 5],
buffer_loops[buffer_loops.size() - 4], buffer_loops[buffer_loops.size() - 3]};
};
{
const auto& [i0, j0, i1, j1] = f_get_loops(state->write_reuse[0]);
sch->Reorder({i1, i0, j0, j1});
sch->ComputeAt(cache_write, i1, true);
}
{
auto loops = f_get_loops(cache_write);
const auto& i0 = loops[0];
const auto& j0 = loops[1];
auto fused = sch->Fuse({i0, j0});
sch->Bind(fused, "threadIdx.y");
}

sch->ReverseComputeInline(state->tensor_core_reindex_store);
TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin);
auto loops = sch->GetLoops(cache_write);
auto blockized_store = sch->Blockize(loops[loops.size() - 2]);
sch->Annotate(blockized_store, tir::attr::meta_schedule_auto_tensorize,
state->intrin_group.store_intrin);

Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
ICHECK_GT(buffer_loops.size(), 5);
sch->Fuse(Array<LoopRV>{buffer_loops.end() - 5, // The src shmem is always 2D
buffer_loops.end()});
AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
return {state};
}
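
To make the restructuring concrete, here is a plain C++ sketch (not TVM code) of the loop order described in the comment above; the extents I0, J0, I1, J1 and the fragment shape are illustrative assumptions, and on the GPU the fused (i0, j0) loop is bound to threadIdx.y rather than iterated sequentially:

#include <cstdio>

int main() {
  const int I0 = 2, J0 = 2;            // warp grid; fused and bound to threadIdx.y in the real schedule
  const int I1 = 2, J1 = 4;            // fragments per warp along m and n
  const int FRAG_M = 16, FRAG_N = 16;  // wmma accumulator fragment shape

  long staged_elems = 0;
  for (int i1 = 0; i1 < I1; ++i1) {                  // hoisted outermost: one row of fragments at a time
    for (int i0_j0 = 0; i0_j0 < I0 * J0; ++i0_j0) {  // fused (i0, j0) -> threadIdx.y
      // Stage 1: accum -> shared via the tensorized store intrinsic.
      for (int j1 = 0; j1 < J1; ++j1)
        for (int m = 0; m < FRAG_M; ++m)
          for (int n = 0; n < FRAG_N; ++n)
            ++staged_elems;
    }
    // Stage 2: the staged row is copied shared -> global with cooperative fetching.
  }

  // Per the comment above, the shared buffer only needs to hold one row of fragments
  // per warp, i.e. [j1, accum_m, accum_n] = 4 * 16 * 16 elements, not the full output tile.
  std::printf("elements staged in total: %ld, live at a time per warp: %d\n",
              staged_elems, J1 * FRAG_M * FRAG_N);
  return 0;
}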

@@ -508,7 +652,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
buffer_sub_index_map.Set(lhs_buffer, sub_index_map);
state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt);
state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map,
/*pad_value=*/NullOpt, /*assume_injective_transform=*/true);
};

for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
@@ -569,6 +714,11 @@ ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);

CHECK(node->reuse_write_.req == ReuseType::kMustReuse &&
runtime::StorageScope::Create(node->reuse_write_.scope).rank ==
runtime::StorageRank::kShared)
<< "ValueError: Shared memory write reuse must be enabled for MultiLevelTilingTensorCore.";

node->intrin_groups.reserve(intrin_groups.size());
for (const auto& intrin_group_config : intrin_groups) {
node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
@@ -48,11 +48,12 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode {
return ScheduleRule(n);
}

Array<tir::LoopRV> SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const;
std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> SplitLoop(const Schedule& sch, BlockRV block,
LoopRV loop, int n_tiles) const;
};

Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv,
LoopRV loop_rv, int n_tiles) const {
std::pair<Array<tir::ExprRV>, Array<tir::LoopRV>> MultiLevelTilingWideVectorNode::SplitLoop(
const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, int n_tiles) const {
const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv));
const tir::StmtSRef block_sref = sch->GetSRef(block_rv);
const tir::BlockNode* block_node = block_sref->StmtAs<tir::BlockNode>();
@@ -99,12 +100,14 @@ Array<tir::LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch
Array<tir::LoopRV> outer_splits = sch->Split(
/*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()});
outer_splits.push_back(inner_splits[1]);
return outer_splits;
outer_factors.push_back(PrimExpr(vec_len));
return {outer_factors, outer_splits};
} else {
Array<tir::ExprRV> factors(n_tiles - 1, PrimExpr(1));
factors.push_back(loop->extent);
return sch->Split(/*loop=*/loop_rv,
/*factors=*/{factors.begin(), factors.end()});
Array<tir::LoopRV> splits = sch->Split(/*loop=*/loop_rv,
/*factors=*/{factors.begin(), factors.end()});
return {factors, splits};
}
}
}
9 changes: 6 additions & 3 deletions src/tir/analysis/block_access_region_detector.cc
@@ -76,8 +76,6 @@ class BlockReadWriteDetector : public StmtExprVisitor {
Map<Var, Buffer> buffer_var_map_;
/*! \brief The target buffer var mapping to its matching */
std::unordered_map<const VarNode*, MatchBufferRegion> match_buffers_;
/*! \brief The analyzer for simplifying*/
arith::Analyzer analyzer_;

/*!
* \brief Update read/write buffers and regions with provided buffer and region
@@ -330,7 +328,12 @@ Array<BufferRegion> BlockReadWriteDetector::CollectRegions(
ICHECK_EQ(buffers[i]->shape.size(), regions[i].size());
for (size_t j = 0; j < regions[i].size(); j++) {
const tvm::arith::IntSet& range = regions[i][j];
region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j])));
if (range.IsSinglePoint()) {
PrimExpr min = range.min();
region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1)));
} else {
region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j])));
}
}
res.push_back(BufferRegion(buffers[i], region));
}