Skip to content

Commit

Permalink
Matmul vector (llvm#1246)
Browse files Browse the repository at this point in the history
Enables better simd code generation for matrix-vector multiplication

Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
Co-authored-by: Kevin O'Brien <caomhin@us.ibm.com>
Co-authored-by: Ettore Tiotto <etiotto@ca.ibm.com>
  • Loading branch information
3 people committed Mar 30, 2022
1 parent 53128f7 commit 5ecd858
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 100 deletions.
113 changes: 74 additions & 39 deletions src/Conversion/KrnlToAffine/KrnlMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

//===-------------- KrnlMatmul.cpp - Lower KrnlMatmulOp -------------------===//
//
// Copyright 2019-2020 The IBM Research Authors.
// Copyright 2019-2022 The IBM Research Authors.
//
// =============================================================================
//
Expand All @@ -27,7 +27,7 @@

#define DEBUG_TYPE "krnl_to_affine"

#define ENABLE_MAT_VECT_MUL 1
static constexpr int32_t DISABLE_MAT_VEC_PRODUCT = 0;

using namespace mlir;
using namespace onnx_mlir;
Expand Down Expand Up @@ -139,7 +139,7 @@ class KrnlMatmulLowering : public ConversionPattern {
kGlobalUB(operandAdaptor.kGlobalUB());

// We have a matrix-times-vector product when the J upper bound is the literal 1.
bool matVectorProduct = ENABLE_MAT_VECT_MUL && jGlobalUB.isLiteral() &&
bool matVectorProduct = !DISABLE_MAT_VEC_PRODUCT && jGlobalUB.isLiteral() &&
jGlobalUB.getLiteral() == 1;

// Investigate SIMD
Expand All @@ -152,9 +152,11 @@ class KrnlMatmulLowering : public ConversionPattern {
if (iComputeTileSize.isLiteral() && kComputeTileSize.isLiteral()) {
uint64_t i = iComputeTileSize.getLiteral();
uint64_t k = kComputeTileSize.getLiteral();
// TODO: longer I & K vectors: (i % k == 0 && (k & (k - 1)) == 0)
if (i == k && k == 4) {
vectorLen = kComputeTileSize;
VectorBuilder createVect(createAffine);
uint64_t mVL = createVect.getMachineVectorLength(elementType);
if (i % mVL == 0 && k % mVL == 0) {
// Right now, vector length must be mVL.
vectorLen = LiteralIndexExpr(mVL);
} else {
simdize = false;
LLVM_DEBUG(llvm::dbgs() << "Matmul: mat*vec with bad sizes: i " << i
Expand Down Expand Up @@ -383,50 +385,83 @@ class KrnlMatmulLowering : public ConversionPattern {
ArrayRef<IndexExpr> cStart, IndexExpr I, IndexExpr J, IndexExpr K,
IndexExpr vectorLen, bool unrollJam) const {
// can simdize only if I & K is compile time
assert(I.isLiteral() && K.isLiteral() &&
assert(I.isLiteral() && K.isLiteral() && vectorLen.isLiteral() &&
"can only simdize with compile time "
"blocking factor on simd axis");

MultiDialectBuilder<MathBuilder, VectorBuilder, AffineBuilderKrnlMem,
MemRefBuilder>
MemRefBuilder, KrnlBuilder>
create(createAffine);
int64_t iLit(I.getLiteral()), VL(vectorLen.getLiteral());
int64_t mVL = create.vec.getMachineVectorLength(elementType);
// Get operands.
KrnlMatMulOpAdaptor operandAdaptor = KrnlMatMulOpAdaptor(op);
Value A(operandAdaptor.A()), B(operandAdaptor.B()), C(operandAdaptor.C());
int64_t aRank(aStart.size());
int64_t aRank(aStart.size()), bRank(bStart.size()), cRank(cStart.size());

// Generate the vector type conversions.
int64_t VL = vectorLen.getLiteral();
assert(VL == mVL && "vector length and VL must be identical for now");
VectorType vecType = VectorType::get({VL}, elementType);
int64_t iUnrollForReduction = K.getLiteral();

// Iterates over the I indices (K is SIMD dim).
// First compute A[i,k]*B[k, 1] for i=0..iUnrollForReduction explicitly.
// We reuse B[k][0] vector for each iteration of i.
SmallVector<Value, 4> bAccess;
IndexExpr::getValues(bStart, bAccess);
// bAccess = {k=0 + bStart0.getValue(), bStart1.getValue()};
IndexExpr::getValues(bStart, bAccess);
Value vb = create.vec.load(vecType, B, bAccess);
SmallVector<Value, 8> vResList;
// Generate computation for each i, preserving the value in vResList.
for (int64_t i = 0; i < iUnrollForReduction; ++i) {
SmallVector<Value, 4> aAccess;
IndexExpr::getValues(aStart, aAccess);
LiteralIndexExpr iVal(i);
aAccess[aRank - 2] = create.math.add(aAccess[aRank - 2], iVal.getValue());
Value va = create.vec.load(vecType, A, aAccess);
Value vres = create.math.mul(va, vb);
vResList.emplace_back(vres);
int64_t iUnrollFactor = iLit;
assert(iUnrollFactor % VL == 0 && "i blocking should be a multiple of VL");

// Have to privatize CTmpType by unroll factor.
MemRefType CTmpType = MemRefType::get({iUnrollFactor}, vecType);
assert(BUFFER_ALIGN >= gDefaultAllocAlign &&
"alignment of buffers cannot be smaller than the default alignment "
"(which is set for SIMD correctness");
Value TmpProd = create.mem.alignedAlloca(CTmpType, BUFFER_ALIGN);
// Init with zero.
Value fZero = create.math.constant(elementType, 0);
Value vFZero = create.vec.broadcast(vecType, fZero);
create.krnl.memset(TmpProd, vFZero);

LiteralIndexExpr zero(0);
create.affineKMem.forIE(
zero, K, VL, [&](AffineBuilderKrnlMem &createAffine, Value k) {
MultiDialectBuilder<MathBuilder, VectorBuilder> create(createAffine);
// Iterates over the I indices (K is SIMD dim).
// First compute A[i,k]*B[k, 1] for i=0..iUnrollFactor explicitly.
// We reuse B[k][0] vector for each iteration of i.
SmallVector<Value, 4> bAccess;
IndexExpr::getValues(bStart, bAccess);
// bAccess = {k + bStart0.getValue(), bStart1.getValue()};
bAccess[bRank - 2] = create.math.add(k, bAccess[bRank - 2]);
Value vb = create.vec.load(vecType, B, bAccess);
// Generate computation for each i, manually unrolled for simplicity.
for (int64_t i = 0; i < iUnrollFactor; ++i) {
SmallVector<Value, 4> aAccess;
IndexExpr::getValues(aStart, aAccess);
Value iVal = create.math.constantIndex(i);
aAccess[aRank - 2] = create.math.add(aAccess[aRank - 2], iVal);
aAccess[aRank - 1] = create.math.add(k, aAccess[aRank - 1]);
Value va = create.vec.load(vecType, A, aAccess);
Value vTmpProd = create.vec.load(vecType, TmpProd, {iVal});
Value vres = create.vec.fma(va, vb, vTmpProd);
create.vec.store(vres, TmpProd, {iVal});
}
});

// Reduce each SIMD vector of length mVL using a SIMD parallel reduction.
SmallVector<Value, 8> vProdList;
for (int64_t i = 0; i < iUnrollFactor; ++i) {
Value iVal = create.math.constantIndex(i);
Value vTmpProd = create.vec.load(vecType, TmpProd, {iVal});
vProdList.emplace_back(vTmpProd);
}
SmallVector<Value, 8> vReductionList;
create.vec.multiReduction(vProdList, vReductionList);
// For each reduction in the list (vector of VL length), load C, add
// reduction, and store C.
uint64_t size = vReductionList.size();
for (uint64_t i = 0; i < size; ++i) {
SmallVector<Value, 4> cAccess;
IndexExpr::getValues(cStart, cAccess);
Value iVal = create.math.constantIndex(i * VL);
cAccess[cRank - 2] = create.math.add(cAccess[cRank - 2], iVal);
Value vc = create.vec.load(vecType, C, cAccess);
vc = create.math.add(vc, vReductionList[i]);
create.vec.store(vc, C, cAccess);
}
// Reduce each SIMD vector of length VL==K using a SIMD parallel reduction.
Value vReduction = create.vec.multiReduction(vResList);
// Add the reduction to the previous value of C.
SmallVector<Value, 4> cAccess;
IndexExpr::getValues(cStart, cAccess);
Value vc = create.vec.load(vecType, C, cAccess);
vc = create.math.add(vc, vReduction);
create.vec.store(vc, C, cAccess);
}

// Simdize along J / memory rows in B and C.
Expand Down
66 changes: 42 additions & 24 deletions src/Conversion/ONNXToKrnl/Math/MatMul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
using namespace mlir;

#define DEBUG_TYPE "matmul"
static constexpr int32_t DISABLE_MAT_VEC_PRODUCT = 0;

struct ONNXMatMulOpLowering : public ConversionPattern {
ONNXMatMulOpLowering(
Expand Down Expand Up @@ -178,29 +179,42 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
});
}

void computeTileSizeForMatVectProduct(DimIndexExpr dimI, DimIndexExpr dimJ,
DimIndexExpr dimK, int64_t &iRegTile, int64_t &jRegTile,
int64_t &kRegTile, bool &simdize) const {
void computeTileSizeForMatVectProduct(int64_t mVL, DimIndexExpr dimI,
DimIndexExpr dimJ, DimIndexExpr dimK, int64_t &iRegTile,
int64_t &jRegTile, int64_t &kRegTile, bool &simdize) const {

// Default values.
// Right can only tile by 4.
iRegTile = 4; // SIMD dim during multi-reduction.
// Right now we can only tile i and k by (possibly distinct) multiples of mVL.
iRegTile = 2 * mVL; // SIMD dim during multi-reduction.
jRegTile = 1;
kRegTile = 4; // SIMD dim during multiplication.

if (dimI.isLiteral()) {
int64_t constI = dimI.getLiteral();
if (constI < iRegTile) {
simdize = false;
// Not enough data, can only support i/k reg tile of 4.
}
}
kRegTile = 16 * mVL; // SIMD dim during multiplication.

if (dimK.isLiteral()) {
int64_t constK = dimK.getLiteral();
if (constK < kRegTile) {
simdize = false;
// Register tile in the I Dim is really for the reduction. The
// computations will be further tiled to a multiple of mVL inside
// krnl.matmul.
kRegTile = (constK / mVL) * mVL; // largest multiple
if (kRegTile > 64 * mVL) {
kRegTile = 64 * mVL;
LLVM_DEBUG({ llvm::dbgs() << "MatMul Vec: cap tiling k\n"; });
} else if (kRegTile < mVL) {
// Not enough data: the k register tile must be at least mVL.
LLVM_DEBUG({ llvm::dbgs() << "MatMul Vec: disable k\n"; });
simdize = false;
kRegTile = 1;
}
}
if (dimI.isLiteral()) {
int64_t constI = dimI.getLiteral();
if (constI < iRegTile) {
iRegTile = (constI / mVL) * mVL; // largest multiple
if (iRegTile < mVL) {
// Not enough data: the i register tile must be at least mVL.
LLVM_DEBUG({ llvm::dbgs() << "MatMul Vec: disable i\n"; });
simdize = false;
iRegTile = 1;
}
}
}
LLVM_DEBUG({
Expand All @@ -209,16 +223,17 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
});
}

// Handle the cases with 2x2 matrices both for A, B, and C without broadcast.
// Implementation here uses the efficient 1d tiling plus kernel substitution.
// Handle the cases with 2x2 matrices both for A, B, and C without
// broadcast. Implementation here uses the efficient 1d tiling plus kernel
// substitution.
void replace2x2Matmul2d(ONNXMatMulOp &matMulOp,
ONNXMatMulOpAdaptor &operandAdaptor, Type elementType,
ONNXMatMulOpShapeHelper &shapeHelper, Value alloc, Value zeroVal,
ConversionPatternRewriter &rewriter, Location loc) const {
// Prepare: loop bounds and zero
Value A(operandAdaptor.A()), B(operandAdaptor.B()), C(alloc);
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
rewriter, loc);
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder, VectorBuilder>
create(rewriter, loc);
Value zero = create.math.constantIndex(0);
Value I = create.mem.dim(C, 0);
Value J = create.mem.dim(C, 1);
Expand All @@ -231,13 +246,16 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
// Define blocking, with simdization along the j axis.
DimIndexExpr dimI(I), dimJ(J), dimK(K);
int64_t iRegTile, jRegTile, kRegTile;
bool isMatVectorProduct = dimJ.isLiteral() && dimJ.getLiteral() == 1;
if (isMatVectorProduct)
bool isMatVectorProduct =
!DISABLE_MAT_VEC_PRODUCT && dimJ.isLiteral() && dimJ.getLiteral() == 1;
if (isMatVectorProduct) {
int64_t mVL = create.vec.getMachineVectorLength(elementType);
computeTileSizeForMatVectProduct(
dimI, dimJ, dimK, iRegTile, jRegTile, kRegTile, simdize);
else
mVL, dimI, dimJ, dimK, iRegTile, jRegTile, kRegTile, simdize);
} else {
computeTileSizeForMatMatProduct(
dimI, dimJ, dimK, iRegTile, jRegTile, kRegTile, simdize);
}

// I, J, K loop.
ValueRange origLoop = create.krnl.defineLoops(3);
Expand Down
Loading

0 comments on commit 5ecd858

Please sign in to comment.