Skip to content

Commit

Permalink
Aling 'linalg-to-xegpu' pass with patched XeGPU dialect
Browse files Browse the repository at this point in the history
Signed-off-by: dchigarev <dmitry.chigarev@intel.com>
  • Loading branch information
dchigarev committed Jul 31, 2024
1 parent 1b9d6ac commit 964398e
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,22 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
Location loc, ValueRange tiles,
ArrayRef<int64_t> offsets) {
SmallVector<Value> updatedTiles;
// convert static offsets to dynamic because of this IMEX bug:
// https://github.com/intel/mlir-extensions/issues/815
std::vector<Value> dynOffsets;
for (auto &x : offsets) {
Value offset = rewriter.create<arith::ConstantIndexOp>(loc, x);
dynOffsets.push_back(offset);
}
ValueRange newOffsets{dynOffsets};
for (auto tile : tiles) {
auto updatedTile =
rewriter
.create<xegpu::UpdateNdOffsetOp>(loc, tile.getType(), tile,
/*offsets=*/ValueRange{}, offsets)
.getResult();
auto updatedTile = rewriter
.create<xegpu::UpdateNdOffsetOp>(
loc, tile.getType(), tile,
/*offsets=*/newOffsets,
SmallVector<int64_t>{ShapedType::kDynamic,
ShapedType::kDynamic})
.getResult();
updatedTiles.push_back(updatedTile);
}

Expand Down Expand Up @@ -648,11 +658,17 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,

SmallVector<Value> tiles;
for (int i = 0; i < loadShape[0]; i += descTile[0]) {
// convert static offsets to dynamic because of this IMEX bug:
// https://github.com/intel/mlir-extensions/issues/815
Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
auto tile = rewriter
.create<xegpu::UpdateNdOffsetOp>(
loc, descType, rootTile,
/*offsets=*/ValueRange{}, SmallVector<int64_t>{i, j})
/*offsets=*/ValueRange{newRowOffs, newColOffs},
SmallVector<int64_t>{ShapedType::kDynamic,
ShapedType::kDynamic})
.getResult();
tiles.push_back(tile);
}
Expand Down Expand Up @@ -732,17 +748,18 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,

VectorType vecLoadType =
VectorType::get(tileType.getShape(), tileType.getElementType());
UnitAttr vnniAxisAttr = nullptr;
mlir::UnitAttr packedAttr = nullptr;
if (vnniConf) {
vnniAxisAttr = UnitAttr::get(rewriter.getContext());
vecLoadType = getVnniVector(tileType.getShape(), tileType.getElementType(),
*vnniConf);
packedAttr = mlir::UnitAttr::get(rewriter.getContext());
}

IntegerAttr transpose_bit = nullptr;
SmallVector<Value> loadVec;
for (auto tile : loadTiles) {

auto loadOp = rewriter.create<xegpu::LoadNdOp>(
loc, vecLoadType, tile, vnniAxisAttr, transpose, nullptr,
loc, vecLoadType, tile, packedAttr, transpose, transpose_bit,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
loadVec.push_back(loadOp);
Expand Down Expand Up @@ -1057,7 +1074,7 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,

// Load A sub-tiles.
SmallVector<Value> loadVecA =
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint, vnniConfA);
loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());

// Load B sub-tiles.
Expand Down

0 comments on commit 964398e

Please sign in to comment.