Added explicit register pressure estimate for SIMD and tuned DynamicLinearQuantization operations (#2945)


Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
AlexandreEichenberger committed Sep 24, 2024
1 parent 087f069 commit 5c53b7e
Showing 13 changed files with 710 additions and 500 deletions.
4 changes: 3 additions & 1 deletion src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -1289,7 +1289,9 @@ template <>
GenOpMix getGenOpMix<ONNXRoundOp>(Type t, Operation *op) {
return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2},
{GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3},
-      {GenericOps::FloorGop, 2}};
+      {GenericOps::FloorGop, 2},
+      {GenericOps::EstimatedVectorRegisterPressure,
+          4 /* Little parallelism in code. */}};
}

template <>
34 changes: 28 additions & 6 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
@@ -662,22 +662,28 @@ int64_t computeSuitableUnrollFactor(MemRefType memRefType,
return 1;
}
  // Gather operation statistics.
-  int64_t vectorizedOpNum, scalarOpNum;
-  double avgVL = VectorMachineSupport::getAvgArchVectorLength(
-      genOps, elementType, vectorizedOpNum, scalarOpNum);
+  int64_t vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure;
+  double avgVL =
+      VectorMachineSupport::getAvgArchVectorLength(genOps, elementType,
+          vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure);
if (avgVL < 1.5) {
LLVM_DEBUG(llvm::dbgs() << " simd disabled: too few SIMD operations with "
<< avgVL << " avg VL\n");
return 1;
}
LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL << "\n");
LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL
<< ", vec op num " << vectorizedOpNum
<< ", max reg pressure "
<< estimatedMaxVectorRegisterPressure << "\n");

// Define a target max unroll as a function of register pressure.
int64_t unrollVL;
int64_t vrNum = VectorMachineSupport::getArchVectorRegisterNum();
-  if (vectorizedOpNum >= vrNum / 2)
+  if (estimatedMaxVectorRegisterPressure >= vrNum)
     unrollVL = 1;
+  else if (estimatedMaxVectorRegisterPressure * 2 >= vrNum)
+    unrollVL = 2;
-  else if (vectorizedOpNum >= vrNum / 4)
+  else if (estimatedMaxVectorRegisterPressure * 4 >= vrNum)
unrollVL = 4;
else
unrollVL = 8;
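
For intuition, the ladder above halves the unroll factor each time the estimated register pressure doubles relative to the size of the register file. A minimal standalone sketch of the same selection logic (pickUnrollVL is a hypothetical name, and the vrNum value in the trailing comment is illustrative):

#include <cstdint>

// Sketch of the unroll ladder above, assuming estimatedPressure counts the
// vector registers live in one copy of the loop body.
static int64_t pickUnrollVL(int64_t estimatedPressure, int64_t vrNum) {
  if (estimatedPressure >= vrNum)
    return 1; // One copy of the body already fills the register file.
  if (estimatedPressure * 2 >= vrNum)
    return 2;
  if (estimatedPressure * 4 >= vrNum)
    return 4;
  return 8; // Plenty of headroom for unrolling.
}
// With vrNum = 32: pressure 4 -> 8x, 8 -> 4x, 16 -> 2x, 32+ -> 1x.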
@@ -743,6 +749,22 @@ int64_t capVLForMaxUnroll(
return archVL * unrollVL;
}

int64_t boostVLForMinUnroll(
MemRefType memRefType, MemRefType convertedMemRefType, int64_t totVL) {
if (totVL == 1)
    return 1; // Simd already disabled, nothing to boost.
Type convertedElementType = convertedMemRefType.getElementType();
int64_t convertedArchVL =
VectorMachineSupport::getArchVectorLength(convertedElementType);
if (convertedArchVL > totVL) {
LLVM_DEBUG(llvm::dbgs()
<< " simd enable: boost totVL to " << convertedArchVL
<< " because of type conversions.\n");
return convertedArchVL;
}
return totVL;
}
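
A plausible call sequence for this helper (a sketch only; the actual call site in this commit sits in the elided part of the QuantizeLinear diff, and the archVL values quoted are illustrative for a 128-bit SIMD machine, where f32 gives archVL 4 and ui8 gives archVL 16):

// If the gen-op mix only justified totVL = 8 on the f32 input side, boosting
// to the converted type's archVL (16) lets the ui8 result be written as one
// full SIMD vector instead of a partial one.
int64_t totVL = computeSuitableUnrollFactor(inputType /* f32 memref */,
    innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
    simdOnly);
totVL = boostVLForMinUnroll(inputType, quantizedType /* ui8 memref */, totVL);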

int64_t capVLForSimdOnly(
MemRefType memRefType, int64_t totVL, int64_t simdLoopStaticTripCount) {
if (totVL == 1)
6 changes: 6 additions & 0 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -663,6 +663,12 @@ int64_t computeSuitableUnrollFactor(mlir::MemRefType memRefType,
// Cap totVL so that it is at most maxUnrollVL * archVL.
int64_t capVLForMaxUnroll(
mlir::MemRefType memRefType, int64_t totVL, int64_t maxUnrollVL);
// In some type-conversion loops we may have a given totVL based on a given
// memRef type and gen-op mix. But the final result may be converted to a
// different type, which may require a minimum unroll to proceed as a single
// SIMD operation. This call adjusts totVL for that case.
int64_t boostVLForMinUnroll(mlir::MemRefType memRefType,
mlir::MemRefType convertedMemRefType, int64_t totVL);
// Enabling a simdOnly code generation scheme by capping totVL so that it
// divides simdLoopStaticTripCount. When not possible (either because
// there is no totVL that divides simdLoopStaticTripCount or trip count is
44 changes: 37 additions & 7 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
@@ -29,8 +29,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax,
Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
bool enableParallel) {
-  MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
-      rewriter, loc);
+  MultiDialectBuilder<KrnlBuilder, MemRefBuilder, VectorBuilder, MathBuilder>
+      create(rewriter, loc);

// Types
Type quantizedElementType = quantizedType.getElementType();
@@ -54,7 +54,9 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5},
{GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2},
{GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3},
-      {GenericOps::FloorGop, 2}};
+      {GenericOps::FloorGop, 2},
+      {GenericOps::EstimatedVectorRegisterPressure,
+          8 /* Little parallelism in code. */}};
totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/,
innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
simdOnly);
@@ -68,8 +70,16 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
inputAF.emplace_back(zero);
DimsExpr outputAF;
outputAF.emplace_back(zero);

  // Faster than the original loop on z16; takes 124us for 64k vals.
// Allocate output buffers.
MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims);
DimsExpr bufferAF;
bufferAF.emplace_back(zero);

create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
-      {flatInput}, {inputAF}, {flatAlloc}, {outputAF},
+      {flatInput}, {inputAF}, {flatBuffer}, {bufferAF},
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
MultiDialectBuilder<MathBuilder> create(kb);
Value x = inputVals[0];
@@ -83,11 +93,31 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
adjustX = create.math.add(roundX, zeroPoint);
else
adjustX = roundX;
-        // Saturate
+        // Saturate: use max into a min.
Value saturateX = create.math.clip(adjustX, qMin, qMax);
-        Value res = create.math.cast(quantizedElementType, saturateX);
-        return res;
+        // Old approach.
+        // return create.math.cast(quantizedElementType, saturateX);
+        return saturateX;
}});

  // A second loop that performs the scalar float-to-int conversion performs
  // better than the compiler's attempt to generate SIMD conversion code. This
  // might not hold for all data types, but it is definitely noticeable with
  // uint8.
  //
  // Investigate further: we might save the vector to a buffer on the fly
  // (avoiding the second loop below), then reload each value as a scalar and
  // store it as a scalar (thus avoiding the insert/extract SIMD operations,
  // which also do not perform well). We could keep SIMD buffers in memory for
  // the non-quantized and quantized values, but then we would also need to
  // privatize them, which is not easy in this scheme. So ignore this for now.
create.krnl.forLoopIE(simdLb, simdUb, 1, enableParallel,
[&](KrnlBuilder &kb, ValueRange loopInd) {
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(kb);
Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]});
Value res = create.math.cast(quantizedElementType, buffVal);
create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]});
});
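
Schematically, the generated code behaves like the following two-pass scalar model (a sketch for intuition only, with hypothetical names; the real lowering goes through the Krnl/SIMD builders above):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Pass 1 stays in float math (SIMD-friendly) and ends in a saturated float
// buffer; pass 2 does the float->uint8 cast one scalar at a time.
void quantizeTwoPass(const std::vector<float> &x, std::vector<uint8_t> &q,
    float scale, float zeroPoint, float qMin, float qMax) {
  std::vector<float> buf(x.size());
  q.resize(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    float v = std::nearbyint(x[i] / scale) + zeroPoint; // Scale, round, adjust.
    buf[i] = std::min(std::max(v, qMin), qMax);         // Saturate.
  }
  for (size_t i = 0; i < x.size(); ++i)
    q[i] = static_cast<uint8_t>(buf[i]); // Scalar conversion loop.
}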

if (totVL > 1)
onnxToKrnlSimdReport(op, /*successful*/ true, totVL,
simdLoopStaticTripCount, "quantizationLinear whole tensor");
23 changes: 23 additions & 0 deletions src/Dialect/Mlir/DialectBuilder.cpp
@@ -2073,6 +2073,29 @@ void VectorBuilder::multiReduction(ArrayRef<Value> inputVecArray,
}
}

Value VectorBuilder::extractElement(Value vector, int64_t index) const {
MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
VectorType type = llvm::cast<VectorType>(vector.getType());
int64_t VL = type.getShape()[0];
assert(type.getRank() == 1 && "expected 1D vector only");
assert(index >= 0 && index < VL && "out of range vector index");
Value position = create.math.constantIndex(index);
return b().create<vector::ExtractElementOp>(loc(), vector, position);
}

Value VectorBuilder::insertElement(
Value vector, Value element, int64_t index) const {
MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
VectorType type = llvm::cast<VectorType>(vector.getType());
int64_t VL = type.getShape()[0];
assert(type.getRank() == 1 && "expected 1D vector only");
assert(index >= 0 && index < VL && "out of range vector index");
Value position = create.math.constantIndex(index);
  // Unlike LLVM's insertelement, which takes <dest, source, position>, the
  // vector dialect takes <source, dest, position>.
return b().create<vector::InsertElementOp>(loc(), element, vector, position);
}
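
A hypothetical per-lane use of the two new helpers (sketch only; rewriter, loc, and vec are assumed to be in scope, and create.vec follows the usual MultiDialectBuilder accessor convention):

// Square every lane of a 1D vector via extract/insert. Illustrative only:
// this is precisely the per-lane pattern the quantization rewrite above tries
// to avoid, since insert/extract do not perform well on z16.
MultiDialectBuilder<VectorBuilder, MathBuilder> create(rewriter, loc);
int64_t VL = llvm::cast<VectorType>(vec.getType()).getShape()[0];
for (int64_t i = 0; i < VL; ++i) {
  Value lane = create.vec.extractElement(vec, i);
  Value squared = create.math.mul(lane, lane);
  vec = create.vec.insertElement(vec, squared, i);
}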

//===----------------------------------------------------------------------===//
// LLVM Builder
//===----------------------------------------------------------------------===//
5 changes: 5 additions & 0 deletions src/Dialect/Mlir/DialectBuilder.hpp
@@ -574,6 +574,11 @@ struct VectorBuilder final : DialectBuilder {
void multiReduction(mlir::ArrayRef<mlir::Value> inputVecArray,
F2 reductionFct, llvm::SmallVectorImpl<mlir::Value> &outputVecArray);

// Insert and extract.
mlir::Value extractElement(mlir::Value vector, int64_t position) const;
mlir::Value insertElement(
mlir::Value vector, mlir::Value element, int64_t position) const;

private:
bool isPowerOf2(uint64_t num) const;
uint64_t getLengthOf1DVector(mlir::Value vec) const;
69 changes: 46 additions & 23 deletions src/Dialect/Mlir/VectorMachineSupport.cpp
@@ -78,21 +78,30 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
}

/*static*/ double VectorMachineSupport::getAvgArchVectorLength(GenOpMix &genOps,
-    Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) {
+    Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum,
+    int64_t &maxVectorRegisterPressure) {
int64_t size = genOps.size();
+  vectorizedOpNum = maxVectorRegisterPressure = 0;
  if (!hasSimd()) {
-    vectorizedOpNum = 0;
    scalarOpNum = size;
    return 1;
  }
  int64_t totProcessedValues = 0.0;
-  vectorizedOpNum = 0;
  scalarOpNum = 0;
+  bool hasRegisterPressure = false;

// Determine which operations support SIMD and accumulate their vector
// lengths.
for (auto pair : genOps) {
GenericOps genOp = pair.first;
int64_t num = pair.second;
// Handle other metrics first.
if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
maxVectorRegisterPressure = std::max(maxVectorRegisterPressure, num);
hasRegisterPressure = true;
continue;
}
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t vl = getArchVectorLength(genOp, elementType);
// If past last value, assume 1; otherwise use actual value.
// Accumulate weighted scalar/vectorized num and vl length.
@@ -106,7 +115,10 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
}
// Compute final values
int64_t totNum = vectorizedOpNum + scalarOpNum;
scalarOpNum = size - vectorizedOpNum;
if (!hasRegisterPressure) {
    // Estimate default register pressure as one per two vector operations.
maxVectorRegisterPressure = std::max(vectorizedOpNum / 2, (int64_t)1);
}
return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 1.0;
}
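
To make the fallback concrete (assuming each counted op contributes its count to vectorizedOpNum): the ONNXRound mix at the top of this commit has 4 + 2 + 3 + 3 + 2 = 14 vectorizable operations, so without an explicit entry the estimate would be max(14 / 2, 1) = 7 registers; the explicit entry of 4 overrides this and permits a deeper unroll on register-rich targets.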

@@ -115,13 +127,13 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
// =============================================================================

int64_t Z16VectorMachineSupport::computeArchVectorLength(
-    GenericOps Gop, Type elementType) {
+    GenericOps genOp, Type elementType) {
+  assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
bool isFloat = mlir::isa<FloatType>(elementType);

// Support shared between int and float.
-  switch (Gop) {
+  switch (genOp) {
case GenericOps::ScalarOnlyGop:
return 1; // Must be scalar.
case GenericOps::SelectGop:
@@ -137,10 +149,10 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
// Supports only 32 and 64 bit Floats; There is support for extended too
// but ignore this for now.
if (!(bitWidth == 32 || bitWidth == 64 ||
-          (bitWidth == 16 && Gop == GenericOps::ConversionGop)))
+          (bitWidth == 16 && genOp == GenericOps::ConversionGop)))
return UNSUPPORTED;
// Now we have a supported length, test for specific operations.
-  switch (Gop) {
+  switch (genOp) {
case GenericOps::AbsGop: /* Supported via compare and select */
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::CeilGop: /* Use load integer & rounding modes*/
@@ -161,7 +173,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
}
}
// Support for integer (we consider bit-wide ops as byte wide ops).
-  switch (Gop) {
+  switch (genOp) {
// 1 - 16 byte operations.
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::ConversionGop:
@@ -190,13 +202,14 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
// =============================================================================

int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
-    GenericOps Gop, Type elementType) {
+    GenericOps genOp, Type elementType) {
+  assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
bool isFloat = mlir::isa<FloatType>(elementType);

// Support shared between int and float.
-  switch (Gop) {
+  switch (genOp) {
case GenericOps::ScalarOnlyGop:
return 1; // Must be scalar.
case GenericOps::SelectGop:
@@ -212,10 +225,10 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
// Supports only 32 and 64 bit Floats; There is support for extended too
// but ignore this for now.
if (!(bitWidth == 32 || bitWidth == 64 ||
-          (bitWidth == 16 && Gop == GenericOps::ConversionGop)))
+          (bitWidth == 16 && genOp == GenericOps::ConversionGop)))
return UNSUPPORTED;
// Now we have a supported length, test for specific operations.
-  switch (Gop) {
+  switch (genOp) {
case GenericOps::AbsGop:
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::CeilGop:
@@ -237,7 +250,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
}
}
// Support for integer (we consider bit-wide ops as byte wide ops).
-  switch (Gop) {
+  switch (genOp) {
// 1 - 16 byte operations.
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::ConversionGop:
@@ -276,13 +289,14 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
// =============================================================================

int64_t NeonVectorMachineSupport::computeArchVectorLength(
-    GenericOps Gop, Type elementType) {
+    GenericOps genOp, Type elementType) {
+  assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
bool isFloat = mlir::isa<FloatType>(elementType);

// Support shared between int and float.
-  switch (Gop) {
+  switch (genOp) {
case GenericOps::ScalarOnlyGop:
return 1; // Must be scalar.
case GenericOps::SelectGop:
@@ -297,10 +311,10 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
if (isFloat) {
// Supports only 32 and 64 bit Floats;
if (!(bitWidth == 32 || bitWidth == 64 ||
-          (bitWidth == 16 && Gop == GenericOps::ConversionGop)))
+          (bitWidth == 16 && genOp == GenericOps::ConversionGop)))
return UNSUPPORTED;
// Now we have a supported length, test for specific operations.
-  switch (Gop) {
+  switch (genOp) {
case GenericOps::AbsGop:
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::CeilGop:
@@ -322,7 +336,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
}
}
// Support for integer (we consider bit-wide ops as byte wide ops).
-  switch (Gop) {
+  switch (genOp) {
// 1 - 16 byte operations.
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::ConversionGop:
@@ -370,10 +384,19 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) {
for (auto pair : mix1) {
GenericOps genOp = pair.first;
int64_t num = pair.second;
-    if (u.find(genOp) != u.end())
-      u[genOp] += num; // Has this op already, add to it.
-    else
+    if (u.find(genOp) != u.end()) {
+      // Merge the 2 operation counts/metrics.
+      if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
+        // For register pressure, pick the max of both.
+        u[genOp] = std::max(u[genOp], num);
+      } else {
+        // For operation count, use the sum of both.
+        u[genOp] += num;
+      }
+    } else {
+      // First time we have this.
+      u[genOp] = num;
+    }
}
return u;
}
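
A small illustration of the union semantics (hypothetical counts; GenOpMix initializer-list style as used elsewhere in this commit):

GenOpMix a = {{GenericOps::MulGop, 2},
    {GenericOps::EstimatedVectorRegisterPressure, 4}};
GenOpMix b = {{GenericOps::MulGop, 3},
    {GenericOps::EstimatedVectorRegisterPressure, 8}};
GenOpMix u = computeGenOpMixUnion(a, b);
// u[MulGop] == 5: operation counts add.
// u[EstimatedVectorRegisterPressure] == 8: pressure takes the max of the two
// estimates rather than their sum.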