Skip to content

Commit

Permalink
fix compile on mac
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Feb 24, 2023
1 parent 4e9c48f commit 112c99b
Showing 1 changed file with 28 additions and 22 deletions.
50 changes: 28 additions & 22 deletions src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,28 +314,34 @@ std::vector<State> MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa
ICHECK_GE(buffer_ndim, 2);
int num_higher_dims = buffer_ndim - 2;

sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite,
tir::IndexMap::FromFunc(buffer_ndim, [&](const Array<tir::Var>& indices) {
Array<PrimExpr> result;
result.reserve(indices.size() + 4);
for (int i = 0; i < num_higher_dims; ++i) {
result.push_back(indices[i]);
}
const auto& m = indices[num_higher_dims];
const auto& n = indices[num_higher_dims + 1];
auto accum_m = floormod(m, frag_shape_m);
auto accum_n = floormod(n, frag_shape_n);
auto outer_m = floordiv(m, frag_shape_m);
auto outer_n = floordiv(n, frag_shape_n);

result.push_back(floordiv(outer_m, warp_num_frag_m));
result.push_back(floordiv(outer_n, warp_num_frag_n));
result.push_back(floormod(outer_m, warp_num_frag_m));
result.push_back(floormod(outer_n, warp_num_frag_n));
result.push_back(accum_m);
result.push_back(accum_n);
return result;
}));
auto index_map =
tir::IndexMap::FromFunc(buffer_ndim,
// frag_shape_m and frag_shape_n are structured bindings, which cannot
// be automatically captured by a lambda until C++20
[&, frag_shape_m = frag_shape_m,
frag_shape_n = frag_shape_n](const Array<tir::Var>& indices) {
Array<PrimExpr> result;
result.reserve(indices.size() + 4);
for (int i = 0; i < num_higher_dims; ++i) {
result.push_back(indices[i]);
}
const auto& m = indices[num_higher_dims];
const auto& n = indices[num_higher_dims + 1];
auto accum_m = floormod(m, frag_shape_m);
auto accum_n = floormod(n, frag_shape_n);
auto outer_m = floordiv(m, frag_shape_m);
auto outer_n = floordiv(n, frag_shape_n);

result.push_back(floordiv(outer_m, warp_num_frag_m));
result.push_back(floordiv(outer_n, warp_num_frag_n));
result.push_back(floormod(outer_m, warp_num_frag_m));
result.push_back(floormod(outer_n, warp_num_frag_n));
result.push_back(accum_m);
result.push_back(accum_n);
return result;
};
sch->TransformLayout(
state->block_rv, 0, tir::BufferIndexType::kWrite, index_map);

// // Tile by the fragment shape
// sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite,
Expand Down

0 comments on commit 112c99b

Please sign in to comment.