Skip to content

Commit

Permalink
added register pressure estimate to be explicit when needed
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
  • Loading branch information
AlexandreEichenberger committed Sep 18, 2024
1 parent 9dd7c4a commit 359e095
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 162 deletions.
4 changes: 3 additions & 1 deletion src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1289,7 +1289,9 @@ template <>
GenOpMix getGenOpMix<ONNXRoundOp>(Type t, Operation *op) {
return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2},
{GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3},
{GenericOps::FloorGop, 2}};
{GenericOps::FloorGop, 2},
{GenericOps::EstimatedVectorRegisterPressure,
4 /* Little parallelism in code. */}};
}

template <>
Expand Down
18 changes: 12 additions & 6 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,22 +662,28 @@ int64_t computeSuitableUnrollFactor(MemRefType memRefType,
return 1;
}
// Gather operation statistics.
int64_t vectorizedOpNum, scalarOpNum;
double avgVL = VectorMachineSupport::getAvgArchVectorLength(
genOps, elementType, vectorizedOpNum, scalarOpNum);
int64_t vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure;
double avgVL =
VectorMachineSupport::getAvgArchVectorLength(genOps, elementType,
vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure);
if (avgVL < 1.5) {
LLVM_DEBUG(llvm::dbgs() << " simd disabled: too few SIMD operations with "
<< avgVL << " avg VL\n");
return 1;
}
LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL << "\n");
LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL
<< ", vec op num " << vectorizedOpNum
<< ", max reg pressure "
<< estimatedMaxVectorRegisterPressure << "\n");

// Define a target max unroll as a function of register pressure.
int64_t unrollVL;
int64_t vrNum = VectorMachineSupport::getArchVectorRegisterNum();
if (vectorizedOpNum >= vrNum / 2)
if (estimatedMaxVectorRegisterPressure >= vrNum)
unrollVL = 1;
else if (estimatedMaxVectorRegisterPressure * 2 >= vrNum)
unrollVL = 2;
else if (vectorizedOpNum >= vrNum / 4)
else if (estimatedMaxVectorRegisterPressure * 4 >= vrNum)
unrollVL = 4;
else
unrollVL = 8;
Expand Down
6 changes: 4 additions & 2 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5},
{GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2},
{GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3},
{GenericOps::FloorGop, 2}};
{GenericOps::FloorGop, 2},
{GenericOps::EstimatedVectorRegisterPressure,
8 /* Little parallelism in code. */}};
totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/,
innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
simdOnly);
Expand Down Expand Up @@ -83,7 +85,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
adjustX = create.math.add(roundX, zeroPoint);
else
adjustX = roundX;
// Saturate
// Saturate: clip to [qMin, qMax], i.e. a max feeding into a min.
Value saturateX = create.math.clip(adjustX, qMin, qMax);
Value res = create.math.cast(quantizedElementType, saturateX);
return res;
Expand Down
69 changes: 47 additions & 22 deletions src/Dialect/Mlir/VectorMachineSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,31 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
}

/*static*/ double VectorMachineSupport::getAvgArchVectorLength(GenOpMix &genOps,
Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) {
Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum,
int64_t &maxVectorRegisterPressure) {
int64_t size = genOps.size();
if (!hasSimd()) {
vectorizedOpNum = 0;
vectorizedOpNum = maxVectorRegisterPressure = 0;
scalarOpNum = size;
return 1;
}
int64_t totProcessedValues = 0.0;
vectorizedOpNum = 0;
vectorizedOpNum = maxVectorRegisterPressure = 0;
scalarOpNum = 0;
bool hasRegisterPressure = false;

// Determine which operations support SIMD and accumulate their vector
// lengths.
for (auto pair : genOps) {
GenericOps genOp = pair.first;
int64_t num = pair.second;
// Handle other metrics first.
if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
maxVectorRegisterPressure = std::max(maxVectorRegisterPressure, num);
hasRegisterPressure = true;
continue;
}
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t vl = getArchVectorLength(genOp, elementType);
// If past last value, assume 1; otherwise use actual value.
// Accumulate weighted scalar/vectorized num and vl length.
Expand All @@ -107,6 +117,10 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
// Compute final values
int64_t totNum = vectorizedOpNum + scalarOpNum;
scalarOpNum = size - vectorizedOpNum;
if (!hasRegisterPressure) {
// No explicit estimate was given: default to one vector register per two
// vector operations, with a minimum of one.
maxVectorRegisterPressure = std::max(vectorizedOpNum / 2, (int64_t)1);
}
return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 1.0;
}

Expand All @@ -115,13 +129,13 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
// =============================================================================

int64_t Z16VectorMachineSupport::computeArchVectorLength(
GenericOps Gop, Type elementType) {
GenericOps genOp, Type elementType) {
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
bool isFloat = mlir::isa<FloatType>(elementType);

// Support shared between int and float.
switch (Gop) {
switch (genOp) {
case GenericOps::ScalarOnlyGop:
return 1; // Must be scalar.
case GenericOps::SelectGop:
Expand All @@ -137,10 +151,10 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
// Supports only 32 and 64 bit Floats; There is support for extended too
// but ignore this for now.
if (!(bitWidth == 32 || bitWidth == 64 ||
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
return UNSUPPORTED;
// Now we have a supported length, test for specific operations.
switch (Gop) {
switch (genOp) {
case GenericOps::AbsGop: /* Supported via compare and select */
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::CeilGop: /* Use load integer & rounding modes*/
Expand All @@ -161,7 +175,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
}
}
// Support for integer (we consider bit-wide ops as byte wide ops).
switch (Gop) {
switch (genOp) {
// 1 - 16 byte operations.
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::ConversionGop:
Expand Down Expand Up @@ -190,13 +204,14 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
// =============================================================================

int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
GenericOps Gop, Type elementType) {
GenericOps genOp, Type elementType) {
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
bool isFloat = mlir::isa<FloatType>(elementType);

// Support shared between int and float.
switch (Gop) {
switch (genOp) {
case GenericOps::ScalarOnlyGop:
return 1; // Must be scalar.
case GenericOps::SelectGop:
Expand All @@ -212,10 +227,10 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
// Supports only 32 and 64 bit Floats; There is support for extended too
// but ignore this for now.
if (!(bitWidth == 32 || bitWidth == 64 ||
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
return UNSUPPORTED;
// Now we have a supported length, test for specific operations.
switch (Gop) {
switch (genOp) {
case GenericOps::AbsGop:
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::CeilGop:
Expand All @@ -237,7 +252,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
}
}
// Support for integer (we consider bit-wide ops as byte wide ops).
switch (Gop) {
switch (genOp) {
// 1 - 16 byte operations.
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::ConversionGop:
Expand Down Expand Up @@ -276,13 +291,14 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
// =============================================================================

int64_t NeonVectorMachineSupport::computeArchVectorLength(
GenericOps Gop, Type elementType) {
GenericOps genOp, Type elementType) {
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
bool isFloat = mlir::isa<FloatType>(elementType);

// Support shared between int and float.
switch (Gop) {
switch (genOp) {
case GenericOps::ScalarOnlyGop:
return 1; // Must be scalar.
case GenericOps::SelectGop:
Expand All @@ -297,10 +313,10 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
if (isFloat) {
// Supports only 32 and 64 bit Floats;
if (!(bitWidth == 32 || bitWidth == 64 ||
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
return UNSUPPORTED;
// Now we have a supported length, test for specific operations.
switch (Gop) {
switch (genOp) {
case GenericOps::AbsGop:
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::CeilGop:
Expand All @@ -322,7 +338,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
}
}
// Support for integer (we consider bit-wide ops as byte wide ops).
switch (Gop) {
switch (genOp) {
// 1 - 16 byte operations.
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::ConversionGop:
Expand Down Expand Up @@ -370,10 +386,19 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) {
for (auto pair : mix1) {
GenericOps genOp = pair.first;
int64_t num = pair.second;
if (u.find(genOp) != u.end())
u[genOp] += num; // Has this op already, add to it.
else
if (u.find(genOp) != u.end()) {
// Merge the 2 operation counts/metrics.
if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
// For register pressure, pick the max of both.
u[genOp] = std::max(u[genOp], num);
} else {
// For operation count, use the sum of both
u[genOp] += num;
}
} else {
// First time we have this.
u[genOp] = num;
}
}
return u;
}
Expand Down
21 changes: 20 additions & 1 deletion src/Dialect/Mlir/VectorMachineSupport.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ namespace onnx_mlir {
// (e.g. all the compares).

enum class GenericOps {
/////////////////////////////////////
// Generic ops.
/////////////////////////////////////

AbsGop,
ArithmeticGop, /* Simple compute ops: add/sub/neg + ops of same complexity. */
CeilDivGop,
Expand Down Expand Up @@ -62,6 +66,17 @@ enum class GenericOps {
TrigArcGop, /* Arc trigonometry ops: asin, acos, atan. */
TrigGop, /* Trigonometry ops: sin, cos, tan. */
TrigHyperbolicGop, /* Hyperbolic trig. */

LastGop, /* Marker of the last op. Used to delineate from other metrics. */

/////////////////////////////////////
// Metrics other than operations.
/////////////////////////////////////

// Metric that provides an estimate of the maximum number of vector registers
// used in a kernel. If none is provided, we estimate the pressure based on
// the number of operations.
EstimatedVectorRegisterPressure,
};

// Describe the mix of Generic operations in a given kernel. Each generic
Expand Down Expand Up @@ -132,8 +147,12 @@ class VectorMachineSupport {
// number of times that generic operation was found. Note that scalar
// operation have a vector length of one in the weighted average as they still
// contribute one result.
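// Max vector register pressure is also reported, either from an explicit
// mention in the genOps (EstimatedVectorRegisterPressure), or, when none is
// given, estimated as one vector register per two vector operations (with a
// minimum of one). It is reported as zero when SIMD is not available.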
static double getAvgArchVectorLength(GenOpMix &genOps, mlir::Type elementType,
int64_t &vectorizedOpNum, int64_t &scalarOpNum);
int64_t &vectorizedOpNum, int64_t &scalarOpNum,
int64_t &maxVectorRegisterPressure);

protected:
// Virtual functions that do the actual work. Called by the "get" functions.
Expand Down
Loading

0 comments on commit 359e095

Please sign in to comment.