diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index edf538fbfd812..0ee9d94c14940 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -314,28 +314,34 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa ICHECK_GE(buffer_ndim, 2); int num_higher_dims = buffer_ndim - 2; - sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite, - tir::IndexMap::FromFunc(buffer_ndim, [&](const Array& indices) { - Array result; - result.reserve(indices.size() + 4); - for (int i = 0; i < num_higher_dims; ++i) { - result.push_back(indices[i]); - } - const auto& m = indices[num_higher_dims]; - const auto& n = indices[num_higher_dims + 1]; - auto accum_m = floormod(m, frag_shape_m); - auto accum_n = floormod(n, frag_shape_n); - auto outer_m = floordiv(m, frag_shape_m); - auto outer_n = floordiv(n, frag_shape_n); - - result.push_back(floordiv(outer_m, warp_num_frag_m)); - result.push_back(floordiv(outer_n, warp_num_frag_n)); - result.push_back(floormod(outer_m, warp_num_frag_m)); - result.push_back(floormod(outer_n, warp_num_frag_n)); - result.push_back(accum_m); - result.push_back(accum_n); - return result; - })); + auto index_map = + tir::IndexMap::FromFunc(buffer_ndim, + // frag_shape_m and frag_shape_n are structural bindings that cannot + // not be automatically captured until c++20 + [&, frag_shape_m = frag_shape_m, + frag_shape_n = frag_shape_n](const Array& indices) { + Array result; + result.reserve(indices.size() + 4); + for (int i = 0; i < num_higher_dims; ++i) { + result.push_back(indices[i]); + } + const auto& m = indices[num_higher_dims]; + const auto& n = indices[num_higher_dims + 1]; + auto accum_m = floormod(m, frag_shape_m); + auto accum_n = floormod(n, frag_shape_n); + auto outer_m = floordiv(m, frag_shape_m); + auto outer_n = floordiv(n, frag_shape_n); + + result.push_back(floordiv(outer_m, warp_num_frag_m)); + result.push_back(floordiv(outer_n, warp_num_frag_n)); + result.push_back(floormod(outer_m, warp_num_frag_m)); + result.push_back(floormod(outer_n, warp_num_frag_n)); + result.push_back(accum_m); + result.push_back(accum_n); + return result; + }; + sch->TransformLayout( + state->block_rv, 0, tir::BufferIndexType::kWrite, index_map); // // Tile by the fragment shape // sch->TransformLayout(state->block_rv, 0, tir::BufferIndexType::kWrite,