-
Notifications
You must be signed in to change notification settings - Fork 661
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[gpu] NFC: Move MMA schedule deduction to Common/GPU/ (#16480)
This prepares it to be shared by the LLVMGPU path.
- Loading branch information
1 parent
39108c4
commit 56725c5
Showing
7 changed files
with
230 additions
and
130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
113 changes: 113 additions & 0 deletions
113
compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h" | ||
#include "llvm/ADT/APInt.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "llvm/Support/MathExtras.h" | ||
#include "llvm/Support/raw_ostream.h" | ||
|
||
#define DEBUG_TYPE "iree-codegen-gpu-heuristics" | ||
|
||
using llvm::APIntOps::GreatestCommonDivisor; | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
std::optional<GPUMMASchedule> | ||
deduceMMASchedule(const GPUMatmulShapeType &problem, | ||
ArrayRef<GPUMatmulShapeType> intrinsics, | ||
const GPUMMAHeuristicSeeds &seeds) { | ||
for (const GPUMatmulShapeType &intrinsic : intrinsics) { | ||
if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType || | ||
problem.cType != intrinsic.cType) { | ||
continue; // Cannot use this intrinsic for mismatched types | ||
} | ||
|
||
if (problem.mSize % intrinsic.mSize != 0 || | ||
problem.nSize % intrinsic.nSize != 0 || | ||
problem.kSize % intrinsic.kSize != 0) { | ||
continue; // Cannot use this intrinsic for misaligned cases | ||
} | ||
|
||
int64_t mTotalTileCount = problem.mSize / intrinsic.mSize; | ||
int64_t nTotalTileCount = problem.nSize / intrinsic.nSize; | ||
|
||
int64_t remainingWarps = seeds.numSubgroupsPerWorkgroup; | ||
int64_t remainingTiles = seeds.numMNTilesPerSubgroup; | ||
// Assign more warps to the M dimension (used later) to balance thread | ||
// counts along X and Y dimensions. | ||
int64_t warpSqrt = 1ull | ||
<< (llvm::divideCeil(llvm::Log2_64(remainingWarps), 2)); | ||
int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2); | ||
|
||
int64_t mWarpCount = 0, nWarpCount = 0; | ||
int64_t mTileCount = 0, nTileCount = 0; | ||
|
||
// See if the square root can divide mTotalTileCount. If so it means we can | ||
// distribute to both dimensions evenly. Otherwise, try to distribute to N | ||
// and then M. | ||
if (mTotalTileCount > (warpSqrt * tileSqrt) && | ||
mTotalTileCount % (warpSqrt * tileSqrt) == 0) { | ||
mWarpCount = warpSqrt; | ||
mTileCount = tileSqrt; | ||
|
||
remainingWarps /= warpSqrt; | ||
remainingTiles /= tileSqrt; | ||
|
||
APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), | ||
APInt(64, remainingWarps)); | ||
nWarpCount = nGCD.getSExtValue(); | ||
nTotalTileCount /= nWarpCount; | ||
remainingWarps /= nWarpCount; | ||
|
||
nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), | ||
APInt(64, remainingTiles)); | ||
nTileCount = nGCD.getSExtValue(); | ||
} else { | ||
APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), | ||
APInt(64, remainingWarps)); | ||
nWarpCount = nGCD.getSExtValue(); | ||
nTotalTileCount /= nWarpCount; | ||
remainingWarps /= nWarpCount; | ||
|
||
nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount), | ||
APInt(64, remainingTiles)); | ||
nTileCount = nGCD.getSExtValue(); | ||
remainingTiles /= nTileCount; | ||
|
||
APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount), | ||
APInt(64, remainingWarps)); | ||
mWarpCount = mGCD.getSExtValue(); | ||
mTotalTileCount /= mWarpCount; | ||
remainingWarps /= mWarpCount; | ||
|
||
mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount), | ||
APInt(64, remainingTiles)); | ||
mTileCount = mGCD.getSExtValue(); | ||
} | ||
|
||
const uint64_t kTotalTileCount = problem.kSize / intrinsic.kSize; | ||
APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCount), | ||
APInt(64, seeds.numKTilesPerSubgroup)); | ||
int64_t kTileCount = kGCD.getSExtValue(); | ||
|
||
LLVM_DEBUG({ | ||
llvm::dbgs() << "chosen MMA schedule:\n"; | ||
llvm::dbgs() << " intrinsic (M, N, K) = (" << intrinsic.mSize << ", " | ||
<< intrinsic.nSize << ", " << intrinsic.kSize << ")\n"; | ||
llvm::dbgs() << " subgroup count (M, N) = (" << mWarpCount << ", " | ||
<< nWarpCount << ")\n"; | ||
llvm::dbgs() << " subgroup tile count (M, N, K) = (" << mTileCount | ||
<< ", " << nTileCount << ", " << kTileCount << ")\n"; | ||
}); | ||
return GPUMMASchedule{intrinsic.mSize, intrinsic.nSize, intrinsic.kSize, | ||
mWarpCount, nWarpCount, mTileCount, | ||
nTileCount, kTileCount}; | ||
} | ||
return std::nullopt; | ||
} | ||
|
||
} // namespace mlir::iree_compiler |
52 changes: 52 additions & 0 deletions
52
compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "mlir/IR/Types.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
/// Struct containing information about a matmul's shape and type. | ||
struct GPUMatmulShapeType { | ||
int64_t mSize; | ||
int64_t nSize; | ||
int64_t kSize; | ||
Type aType; | ||
Type bType; | ||
Type cType; | ||
|
||
GPUMatmulShapeType(int64_t m, int64_t n, int64_t k, Type a, Type b, Type c) | ||
: mSize(m), nSize(n), kSize(k), aType(a), bType(b), cType(c) {} | ||
}; | ||
|
||
/// Struct containing seed tile sizes for GPU MMA heuristics deduction logic.
///
/// Every member defaults to 0 so a default-initialized instance never carries
/// indeterminate values. The struct remains an aggregate, so brace
/// initialization with explicit values keeps working.
struct GPUMMAHeuristicSeeds {
  // The default number of subgroups to use per workgroup
  int64_t numSubgroupsPerWorkgroup = 0;
  // The default number of tiles along M/N dimension to use per subgroup
  int64_t numMNTilesPerSubgroup = 0;
  // The default number of tiles along K dimension to use per subgroup
  int64_t numKTilesPerSubgroup = 0;
};
|
||
/// Struct describing a deduced MMA schedule: the native MMA sizes plus the
/// subgroup and tile counts along each dimension.
///
/// Every member defaults to 0 so a default-initialized instance never carries
/// indeterminate values. The struct remains an aggregate, so brace
/// initialization with all eight values keeps working.
struct GPUMMASchedule {
  int64_t mSize = 0;      // Native MMA size along M dimension
  int64_t nSize = 0;      // Native MMA size along N dimension
  int64_t kSize = 0;      // Native MMA size along K dimension
  int64_t mWarpCount = 0; // Number of subgroups along M dimension
  int64_t nWarpCount = 0; // Number of subgroups along N dimension
  int64_t mTileCount = 0; // Number of tiles per subgroup along M dimension
  int64_t nTileCount = 0; // Number of tiles per subgroup along N dimension
  int64_t kTileCount = 0; // Number of tiles along K dimension
};
|
||
/// Returns a schedule for using one of the given MMA |intrinsics| to target the | ||
/// input |problem|. Returns std::nullopt if we cannot find such a schedule. | ||
std::optional<GPUMMASchedule> | ||
deduceMMASchedule(const GPUMatmulShapeType &problem, | ||
ArrayRef<GPUMatmulShapeType> intrinsics, | ||
const GPUMMAHeuristicSeeds &seeds); | ||
|
||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.