Skip to content

Commit

Permalink
[gpu] NFC: Move MMA schedule deduction to Common/GPU/ (#16480)
Browse files Browse the repository at this point in the history
This prepares it to be shared by the LLVMGPU path.
  • Loading branch information
antiagainst authored Feb 19, 2024
1 parent 39108c4 commit 56725c5
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 130 deletions.
15 changes: 15 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,18 @@ iree_compiler_cc_library(
"@llvm-project//mlir:VectorTransforms",
],
)

# MMA schedule deduction heuristics, kept in their own target so that GPU
# codegen backends can depend on them without pulling in CommonGPUPasses.
iree_compiler_cc_library(
name = "GPUHeuristics",
srcs = [
"GPUHeuristics.cpp",
],
hdrs = [
"GPUHeuristics.h",
],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
14 changes: 14 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,18 @@ iree_cc_library(
PUBLIC
)

iree_cc_library(
NAME
GPUHeuristics
HDRS
"GPUHeuristics.h"
SRCS
"GPUHeuristics.cpp"
DEPS
LLVMSupport
MLIRIR
MLIRSupport
PUBLIC
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
113 changes: 113 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "iree-codegen-gpu-heuristics"

using llvm::APIntOps::GreatestCommonDivisor;

namespace mlir::iree_compiler {

std::optional<GPUMMASchedule>
deduceMMASchedule(const GPUMatmulShapeType &problem,
ArrayRef<GPUMatmulShapeType> intrinsics,
const GPUMMAHeuristicSeeds &seeds) {
for (const GPUMatmulShapeType &intrinsic : intrinsics) {
if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType ||
problem.cType != intrinsic.cType) {
continue; // Cannot use this intrinsic for mismatched types
}

if (problem.mSize % intrinsic.mSize != 0 ||
problem.nSize % intrinsic.nSize != 0 ||
problem.kSize % intrinsic.kSize != 0) {
continue; // Cannot use this intrinsic for misaligned cases
}

int64_t mTotalTileCount = problem.mSize / intrinsic.mSize;
int64_t nTotalTileCount = problem.nSize / intrinsic.nSize;

int64_t remainingWarps = seeds.numSubgroupsPerWorkgroup;
int64_t remainingTiles = seeds.numMNTilesPerSubgroup;
// Assign more warps to the M dimension (used later) to balance thread
// counts along X and Y dimensions.
int64_t warpSqrt = 1ull
<< (llvm::divideCeil(llvm::Log2_64(remainingWarps), 2));
int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);

int64_t mWarpCount = 0, nWarpCount = 0;
int64_t mTileCount = 0, nTileCount = 0;

// See if the square root can divide mTotalTileCount. If so it means we can
// distribute to both dimensions evenly. Otherwise, try to distribute to N
// and then M.
if (mTotalTileCount > (warpSqrt * tileSqrt) &&
mTotalTileCount % (warpSqrt * tileSqrt) == 0) {
mWarpCount = warpSqrt;
mTileCount = tileSqrt;

remainingWarps /= warpSqrt;
remainingTiles /= tileSqrt;

APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
APInt(64, remainingWarps));
nWarpCount = nGCD.getSExtValue();
nTotalTileCount /= nWarpCount;
remainingWarps /= nWarpCount;

nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
APInt(64, remainingTiles));
nTileCount = nGCD.getSExtValue();
} else {
APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
APInt(64, remainingWarps));
nWarpCount = nGCD.getSExtValue();
nTotalTileCount /= nWarpCount;
remainingWarps /= nWarpCount;

nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
APInt(64, remainingTiles));
nTileCount = nGCD.getSExtValue();
remainingTiles /= nTileCount;

APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount),
APInt(64, remainingWarps));
mWarpCount = mGCD.getSExtValue();
mTotalTileCount /= mWarpCount;
remainingWarps /= mWarpCount;

mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount),
APInt(64, remainingTiles));
mTileCount = mGCD.getSExtValue();
}

const uint64_t kTotalTileCount = problem.kSize / intrinsic.kSize;
APInt kGCD = GreatestCommonDivisor(APInt(64, kTotalTileCount),
APInt(64, seeds.numKTilesPerSubgroup));
int64_t kTileCount = kGCD.getSExtValue();

LLVM_DEBUG({
llvm::dbgs() << "chosen MMA schedule:\n";
llvm::dbgs() << " intrinsic (M, N, K) = (" << intrinsic.mSize << ", "
<< intrinsic.nSize << ", " << intrinsic.kSize << ")\n";
llvm::dbgs() << " subgroup count (M, N) = (" << mWarpCount << ", "
<< nWarpCount << ")\n";
llvm::dbgs() << " subgroup tile count (M, N, K) = (" << mTileCount
<< ", " << nTileCount << ", " << kTileCount << ")\n";
});
return GPUMMASchedule{intrinsic.mSize, intrinsic.nSize, intrinsic.kSize,
mWarpCount, nWarpCount, mTileCount,
nTileCount, kTileCount};
}
return std::nullopt;
}

} // namespace mlir::iree_compiler
52 changes: 52 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "mlir/IR/Types.h"

namespace mlir::iree_compiler {

/// Bundles the M/N/K sizes of a matmul together with the types of A, B, and C.
/// Used to describe both the problem to schedule and each candidate intrinsic.
struct GPUMatmulShapeType {
  // Sizes along the M, N, and K dimensions.
  int64_t mSize;
  int64_t nSize;
  int64_t kSize;
  // Types of A, B, and C; compared for equality when matching a problem
  // against an intrinsic.
  Type aType;
  Type bType;
  Type cType;

  GPUMatmulShapeType(int64_t mDim, int64_t nDim, int64_t kDim, Type aTy,
                     Type bTy, Type cTy)
      : mSize(mDim),
        nSize(nDim),
        kSize(kDim),
        aType(aTy),
        bType(bTy),
        cType(cTy) {}
};

/// Struct containing seed tile sizes for GPU MMA heuristics deduction logic.
/// The seeds act as budgets: deduceMMASchedule only picks counts that divide
/// the problem evenly, so the counts in the resulting schedule may be smaller
/// than the values given here.
struct GPUMMAHeuristicSeeds {
// The default number of subgroups to use per workgroup; this budget is split
// between the M and N dimensions.
int64_t numSubgroupsPerWorkgroup;
// The default number of tiles to use per subgroup, counted over the M and N
// dimensions together (i.e. the budget for M tile count * N tile count).
int64_t numMNTilesPerSubgroup;
// The default number of tiles along K dimension to use per subgroup
int64_t numKTilesPerSubgroup;
};

/// Struct describing the schedule chosen by deduceMMASchedule: the native size
/// of the selected MMA intrinsic plus how many subgroups and per-subgroup tiles
/// are used along each dimension.
///
/// NOTE: this is an aggregate that is brace-initialized positionally, so the
/// field order is part of its interface.
struct GPUMMASchedule {
int64_t mSize; // Native MMA size along M dimension
int64_t nSize; // Native MMA size along N dimension
int64_t kSize; // Native MMA size along K dimension
int64_t mWarpCount; // Number of subgroups along M dimension
int64_t nWarpCount; // Number of subgroups along N dimension
int64_t mTileCount; // Number of tiles per subgroup along M dimension
int64_t nTileCount; // Number of tiles per subgroup along N dimension
int64_t kTileCount; // Number of tiles along K dimension
};

/// Returns a schedule for using one of the given MMA |intrinsics| to target the
/// input |problem|. Returns std::nullopt if we cannot find such a schedule.
///
/// |intrinsics| are tried in the order given and the first usable one wins: its
/// A/B/C types must equal those of |problem| and its M/N/K sizes must evenly
/// divide the problem's sizes. |seeds| provides the subgroup and tile budgets
/// that are then distributed over the M, N, and K dimensions.
std::optional<GPUMMASchedule>
deduceMMASchedule(const GPUMatmulShapeType &problem,
ArrayRef<GPUMatmulShapeType> intrinsics,
const GPUMMAHeuristicSeeds &seeds);

} // namespace mlir::iree_compiler
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Common",
"//compiler/src/iree/compiler/Codegen/Common:TransformDialectInterpreterPass",
"//compiler/src/iree/compiler/Codegen/Common/GPU:CommonGPUPasses",
"//compiler/src/iree/compiler/Codegen/Common/GPU:GPUHeuristics",
"//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
"//compiler/src/iree/compiler/Codegen/TransformStrategies/GPU",
"//compiler/src/iree/compiler/Codegen/Transforms",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ iree_cc_library(
MLIRVectorTransforms
iree::compiler::Codegen::Common
iree::compiler::Codegen::Common::GPU::CommonGPUPasses
iree::compiler::Codegen::Common::GPU::GPUHeuristics
iree::compiler::Codegen::Common::TransformDialectInterpreterPass
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
iree::compiler::Codegen::TransformStrategies::GPU
Expand Down
Loading

0 comments on commit 56725c5

Please sign in to comment.