From 964398e27406684c10232a2b759412abd7dbbabd Mon Sep 17 00:00:00 2001 From: dchigarev Date: Tue, 30 Jul 2024 15:55:30 +0000 Subject: [PATCH] Align 'linalg-to-xegpu' pass with patched XeGPU dialect Signed-off-by: dchigarev --- lib/gc/Transforms/GPU/LinalgToXeGPU.cpp | 39 ++++++++++++++++++------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp index bc4326abc..eacb5933b 100644 --- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp +++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp @@ -597,12 +597,22 @@ static SmallVector updateTilesOffsets(PatternRewriter &rewriter, Location loc, ValueRange tiles, ArrayRef offsets) { SmallVector updatedTiles; + // convert static offsets to dynamic because of this IMEX bug: + // https://github.com/intel/mlir-extensions/issues/815 + std::vector dynOffsets; + for (auto &x : offsets) { + Value offset = rewriter.create(loc, x); + dynOffsets.push_back(offset); + } + ValueRange newOffsets{dynOffsets}; for (auto tile : tiles) { - auto updatedTile = - rewriter - .create(loc, tile.getType(), tile, - /*offsets=*/ValueRange{}, offsets) - .getResult(); + auto updatedTile = rewriter + .create( + loc, tile.getType(), tile, + /*offsets=*/newOffsets, + SmallVector{ShapedType::kDynamic, + ShapedType::kDynamic}) + .getResult(); updatedTiles.push_back(updatedTile); } @@ -648,11 +658,17 @@ static SmallVector createDescriptorTiles(PatternRewriter &rewriter, SmallVector tiles; for (int i = 0; i < loadShape[0]; i += descTile[0]) { + // convert static offsets to dynamic because of this IMEX bug: + // https://github.com/intel/mlir-extensions/issues/815 + Value newRowOffs = rewriter.create(loc, i); for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) { + Value newColOffs = rewriter.create(loc, j); auto tile = rewriter .create( loc, descType, rootTile, - /*offsets=*/ValueRange{}, SmallVector{i, j}) + /*offsets=*/ValueRange{newRowOffs, newColOffs}, + 
SmallVector{ShapedType::kDynamic, + ShapedType::kDynamic}) .getResult(); tiles.push_back(tile); } @@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles, VectorType vecLoadType = VectorType::get(tileType.getShape(), tileType.getElementType()); - UnitAttr vnniAxisAttr = nullptr; + mlir::UnitAttr packedAttr = nullptr; if (vnniConf) { - vnniAxisAttr = UnitAttr::get(rewriter.getContext()); vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(), *vnniConf); + packedAttr = mlir::UnitAttr::get(rewriter.getContext()); } - + IntegerAttr transpose_bit = nullptr; SmallVector loadVec; for (auto tile : loadTiles) { + auto loadOp = rewriter.create( - loc, vecLoadType, tile, vnniAxisAttr, transpose, nullptr, + loc, vecLoadType, tile, packedAttr, transpose, transpose_bit, /*l1_hint=*/hint, /*l2_hint=*/hint, /*l3_hint=*/hint); loadVec.push_back(loadOp); @@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp, // Load A sub-tiles. SmallVector loadVecA = - loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA); + loadNdDescTiles(rewriter, loc, tilesA, readCacheHint); auto tileTypeA = cast(tilesA[0].getType()); // Load B sub-tiles.