Added explicit register pressure estimate for SIMD and tuned [Dynamic]LinearQuantization operations #2945

Merged
4 changes: 3 additions & 1 deletion src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -1289,7 +1289,9 @@ template <>
GenOpMix getGenOpMix<ONNXRoundOp>(Type t, Operation *op) {
return {{GenericOps::ArithmeticGop, 4}, {GenericOps::MulGop, 2},
{GenericOps::CompareGop, 3}, {GenericOps::SelectGop, 3},
{GenericOps::FloorGop, 2}};
{GenericOps::FloorGop, 2},
{GenericOps::EstimatedVectorRegisterPressure,
4 /* Little parallelism in code. */}};
}

template <>
34 changes: 28 additions & 6 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
@@ -662,22 +662,28 @@ int64_t computeSuitableUnrollFactor(MemRefType memRefType,
return 1;
}
// Gather operation statistics.
int64_t vectorizedOpNum, scalarOpNum;
double avgVL = VectorMachineSupport::getAvgArchVectorLength(
genOps, elementType, vectorizedOpNum, scalarOpNum);
int64_t vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure;
double avgVL =
VectorMachineSupport::getAvgArchVectorLength(genOps, elementType,
vectorizedOpNum, scalarOpNum, estimatedMaxVectorRegisterPressure);
if (avgVL < 1.5) {
LLVM_DEBUG(llvm::dbgs() << " simd disabled: too few SIMD operations with "
<< avgVL << " avg VL\n");
return 1;
}
LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL << "\n");
LLVM_DEBUG(llvm::dbgs() << " simd enable: avg vl " << avgVL
<< ", vec op num " << vectorizedOpNum
<< ", max reg pressure "
<< estimatedMaxVectorRegisterPressure << "\n");

// Define a target max unroll as a function of register pressure.
int64_t unrollVL;
int64_t vrNum = VectorMachineSupport::getArchVectorRegisterNum();
if (vectorizedOpNum >= vrNum / 2)
if (estimatedMaxVectorRegisterPressure >= vrNum)
unrollVL = 1;
else if (estimatedMaxVectorRegisterPressure * 2 >= vrNum)
unrollVL = 2;
else if (vectorizedOpNum >= vrNum / 4)
else if (estimatedMaxVectorRegisterPressure * 4 >= vrNum)
unrollVL = 4;
else
unrollVL = 8;
@@ -743,6 +749,22 @@ int64_t capVLForMaxUnroll(
return archVL * unrollVL;
}

int64_t boostVLForMinUnroll(
MemRefType memRefType, MemRefType convertedMemRefType, int64_t totVL) {
if (totVL == 1)
return 1; // SIMD already disabled, nothing to boost.
Type convertedElementType = convertedMemRefType.getElementType();
int64_t convertedArchVL =
VectorMachineSupport::getArchVectorLength(convertedElementType);
if (convertedArchVL > totVL) {
LLVM_DEBUG(llvm::dbgs()
<< " simd enable: boost totVL to " << convertedArchVL
<< " because of type conversions.\n");
return convertedArchVL;
}
return totVL;
}

int64_t capVLForSimdOnly(
MemRefType memRefType, int64_t totVL, int64_t simdLoopStaticTripCount) {
if (totVL == 1)
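The hunk above ties the unroll factor to the new register-pressure estimate rather than to the raw count of vectorized operations. Below is a minimal standalone sketch of that selection ladder; the helper name selectUnrollVL and the main driver are hypothetical, written only to illustrate the arithmetic, while the real logic lives in computeSuitableUnrollFactor.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Distilled version of the ladder above: the more vector registers one
// iteration of the op mix is expected to occupy, the smaller the extra
// unrolling, so that (pressure * unroll) stays within the register file.
static int64_t selectUnrollVL(int64_t estimatedMaxVectorRegisterPressure,
    int64_t archVectorRegisterNum) {
  assert(estimatedMaxVectorRegisterPressure > 0 && archVectorRegisterNum > 0);
  if (estimatedMaxVectorRegisterPressure >= archVectorRegisterNum)
    return 1;
  if (estimatedMaxVectorRegisterPressure * 2 >= archVectorRegisterNum)
    return 2;
  if (estimatedMaxVectorRegisterPressure * 4 >= archVectorRegisterNum)
    return 4;
  return 8;
}

int main() {
  // Assuming 32 vector registers (e.g. z16): the pressure of 8 used for
  // QuantizeLinear below yields unroll 4, the pressure of 4 used for Round
  // yields unroll 8, and a pressure at or above 32 disables extra unrolling.
  printf("%lld\n", (long long)selectUnrollVL(8, 32));  // 4
  printf("%lld\n", (long long)selectUnrollVL(4, 32));  // 8
  printf("%lld\n", (long long)selectUnrollVL(32, 32)); // 1
  return 0;
}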
6 changes: 6 additions & 0 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp
@@ -663,6 +663,12 @@ int64_t computeSuitableUnrollFactor(mlir::MemRefType memRefType,
// Cap totVL so that it is at most maxUnrollVL * archVL.
int64_t capVLForMaxUnroll(
mlir::MemRefType memRefType, int64_t totVL, int64_t maxUnrollVL);
// In some type conversion loops we may have a given totVL based on a given
// memRef type and gen op mix. But the final result may be converted to a
// different type, which may require a minimum unroll to proceed as a single
// SIMD operation. This call adjusts totVL for that case.
int64_t boostVLForMinUnroll(mlir::MemRefType memRefType,
    mlir::MemRefType convertedMemRefType, int64_t totVL);

Collaborator Author (review comment): Also needed for some hybrid loops where multiple precisions were present in the same SIMD loop. Currently not needed in the (best performing) final version, but it would be useful in similar situations in the future.
// Enabling a simdOnly code generation scheme by capping totVL so that it
// divides simdLoopStaticTripCount. When not possible (either because
// there is no totVL that divides simdLoopStaticTripCount or trip count is
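A hedged sketch of how the declarations above might be chained inside a lowering such as QuantizeLinear. The computeSuitableUnrollFactor call mirrors the one in QuantizeLinear.cpp below; the variables inputType, quantizedOutputType, innermostLoopCollapse, mix, and canOverCompute are assumed to come from the surrounding pattern, and the exact sequence is an illustration rather than code from this PR.

// Illustrative composition of the VL helpers (not verbatim from the PR).
int64_t simdLoopStaticTripCount;
bool simdOnly;
// Pick an unroll factor from the op mix and the unquantized element type.
int64_t totVL = computeSuitableUnrollFactor(inputType /* use unquantized type */,
    innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
    simdOnly);
// The results are stored as a narrower (converted) type, e.g. ui8: make sure
// totVL covers at least one full vector of that type so the store can remain
// a single SIMD operation.
totVL = boostVLForMinUnroll(inputType, quantizedOutputType, totVL);
// Optionally cap totVL so that it divides the static trip count (simd-only).
totVL = capVLForSimdOnly(inputType, totVL, simdLoopStaticTripCount);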
44 changes: 37 additions & 7 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
@@ -29,8 +29,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax,
Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD,
bool enableParallel) {
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(
rewriter, loc);
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, VectorBuilder, MathBuilder>
create(rewriter, loc);

// Types
Type quantizedElementType = quantizedType.getElementType();
@@ -54,7 +54,9 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
GenOpMix mix = {{GenericOps::DivGop, 1}, {GenericOps::ArithmeticGop, 5},
{GenericOps::ConversionGop, 1}, {GenericOps::MinMaxGop, 2},
{GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3},
{GenericOps::FloorGop, 2}};
{GenericOps::FloorGop, 2},
{GenericOps::EstimatedVectorRegisterPressure,
8 /* Little parallelism in code. */}};
totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/,
innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
simdOnly);
@@ -68,8 +70,16 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
inputAF.emplace_back(zero);
DimsExpr outputAF;
outputAF.emplace_back(zero);

// Faster than the original loop on z16: takes 124us for 64k values.
// Allocate output buffers.
MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims);
DimsExpr bufferAF;
bufferAF.emplace_back(zero);

create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
{flatInput}, {inputAF}, {flatBuffer}, {bufferAF},
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
MultiDialectBuilder<MathBuilder> create(kb);
Value x = inputVals[0];
@@ -83,11 +93,31 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
adjustX = create.math.add(roundX, zeroPoint);
else
adjustX = roundX;
// Saturate
// Saturate: use max into a min.
Value saturateX = create.math.clip(adjustX, qMin, qMax);
Value res = create.math.cast(quantizedElementType, saturateX);
return res;
// Old approach.
// return create.math.cast(quantizedElementType, saturateX);
return saturateX;
}});

// A second loop that performs the float-to-int conversion with scalar code
// performs better than the compiler's attempt to generate SIMD conversion
// code. This might not hold for all data types, but it is definitely
// noticeable with uint8.
//
// Investigate further: we might save the vector to a buffer on the fly
// (avoiding the second loop below), then reload each value as a scalar and
// store it as a scalar (thus avoiding the insert/extract SIMD operations,
// which also do not perform well). We could keep SIMD buffers in memory for
// the non-quantized and quantized SIMD values, but then we would also need to
// privatize them, which is not easy in this scheme. So ignore this for now.
create.krnl.forLoopIE(simdLb, simdUb, 1, enableParallel,
[&](KrnlBuilder &kb, ValueRange loopInd) {
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(kb);
Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]});
Value res = create.math.cast(quantizedElementType, buffVal);
create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]});
});

if (totVL > 1)
onnxToKrnlSimdReport(op, /*successful*/ true, totVL,
simdLoopStaticTripCount, "quantizationLinear whole tensor");
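For readers less familiar with the Krnl builder API, the generated structure is roughly equivalent to the scalar C++ sketch below. It is a simplification under assumed float input and uint8 output; the clip bounds, rounding mode, and the absence of SIMD/parallelism are stand-ins for what the builders emit above.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Rough scalar equivalent of the two-loop scheme above: quantize in float
// into a temporary buffer (the SIMD-friendly part), then convert and store as
// uint8 in a separate scalar loop (the part that vectorizes poorly on z16).
void quantizeLinearTwoLoops(const std::vector<float> &x, float scale,
    float zeroPoint, std::vector<uint8_t> &out) {
  const float qMin = 0.0f, qMax = 255.0f; // assumed ui8 quantization range
  std::vector<float> buffer(x.size());    // plays the role of flatBuffer
  out.resize(x.size());
  // Loop 1: divide, round, add zero point, saturate -- all in float.
  for (size_t i = 0; i < x.size(); ++i) {
    float v = x[i] / scale;
    v = std::nearbyint(v); // stands in for the rounding performed above
    v += zeroPoint;
    buffer[i] = std::min(std::max(v, qMin), qMax);
  }
  // Loop 2: scalar float -> integer conversion and store into the result.
  for (size_t i = 0; i < x.size(); ++i)
    out[i] = static_cast<uint8_t>(buffer[i]);
}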
23 changes: 23 additions & 0 deletions src/Dialect/Mlir/DialectBuilder.cpp
@@ -2073,6 +2073,29 @@ void VectorBuilder::multiReduction(ArrayRef<Value> inputVecArray,
}
}

Collaborator Author (review comment): Added for experiments on how to get better performance; currently not in use, but verified to work during the experiments associated with this PR.

Value VectorBuilder::extractElement(Value vector, int64_t index) const {
MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
VectorType type = llvm::cast<VectorType>(vector.getType());
int64_t VL = type.getShape()[0];
assert(type.getRank() == 1 && "expected 1D vector only");
assert(index >= 0 && index < VL && "out of range vector index");
Value position = create.math.constantIndex(index);
return b().create<vector::ExtractElementOp>(loc(), vector, position);
}

Value VectorBuilder::insertElement(
Value vector, Value element, int64_t index) const {
MultiDialectBuilder<VectorBuilder, MathBuilder> create(*this);
VectorType type = llvm::cast<VectorType>(vector.getType());
int64_t VL = type.getShape()[0];
assert(type.getRank() == 1 && "expected 1D vector only");
assert(index >= 0 && index < VL && "out of range vector index");
Value position = create.math.constantIndex(index);
// Unlike LLVM's insertelement, which takes <dest, source, position>, the
// vector dialect takes <source, dest, position>.
return b().create<vector::InsertElementOp>(loc(), element, vector, position);
}

//===----------------------------------------------------------------------===//
// LLVM Builder
//===----------------------------------------------------------------------===//
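Since the review note above says these helpers are not exercised by the final code, here is a hedged usage sketch. It assumes a builder-based lowering where rewriter, loc, and a 1-D vector SSA value vec already exist; the builder accessors (create.vec, create.math) and MathBuilder calls follow the surrounding DialectBuilder conventions but are assumptions, not code from this PR.

// Hypothetical use of the new helpers: pull lane 0 out of a 1-D vector,
// double it as a scalar, and write the result back into lane 3.
MultiDialectBuilder<VectorBuilder, MathBuilder> create(rewriter, loc);
Value lane0 = create.vec.extractElement(vec, /*position=*/0);
Value two = create.math.constant(lane0.getType(), 2.0);
Value doubled = create.math.mul(lane0, two);
Value updated = create.vec.insertElement(vec, doubled, /*position=*/3);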
5 changes: 5 additions & 0 deletions src/Dialect/Mlir/DialectBuilder.hpp
@@ -574,6 +574,11 @@ struct VectorBuilder final : DialectBuilder {
void multiReduction(mlir::ArrayRef<mlir::Value> inputVecArray,
F2 reductionFct, llvm::SmallVectorImpl<mlir::Value> &outputVecArray);

// Insert and extract.
mlir::Value extractElement(mlir::Value vector, int64_t position) const;
mlir::Value insertElement(
mlir::Value vector, mlir::Value element, int64_t position) const;

private:
bool isPowerOf2(uint64_t num) const;
uint64_t getLengthOf1DVector(mlir::Value vec) const;
69 changes: 46 additions & 23 deletions src/Dialect/Mlir/VectorMachineSupport.cpp
@@ -78,21 +78,30 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
}

/*static*/ double VectorMachineSupport::getAvgArchVectorLength(GenOpMix &genOps,
Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum) {
Type elementType, int64_t &vectorizedOpNum, int64_t &scalarOpNum,
int64_t &maxVectorRegisterPressure) {
int64_t size = genOps.size();
vectorizedOpNum = maxVectorRegisterPressure = 0;
if (!hasSimd()) {
vectorizedOpNum = 0;
scalarOpNum = size;
return 1;
}
int64_t totProcessedValues = 0;
vectorizedOpNum = 0;
scalarOpNum = 0;
bool hasRegisterPressure = false;

// Determine which operations support SIMD and accumulate their vector
// lengths.
for (auto pair : genOps) {
GenericOps genOp = pair.first;
int64_t num = pair.second;
// Handle other metrics first.
if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
maxVectorRegisterPressure = std::max(maxVectorRegisterPressure, num);
hasRegisterPressure = true;
continue;
}
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t vl = getArchVectorLength(genOp, elementType);
// If past last value, assume 1; otherwise use actual value.
// Accumulate weighted scalar/vectorized num and vl length.
@@ -106,7 +115,10 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
}
// Compute final values
int64_t totNum = vectorizedOpNum + scalarOpNum;
scalarOpNum = size - vectorizedOpNum;
if (!hasRegisterPressure) {
// Estimate default register pressure as one register per two vector operations.
maxVectorRegisterPressure = std::max(vectorizedOpNum / 2, (int64_t)1);
}
return totNum != 0 ? (1.0 * totProcessedValues) / (1.0 * totNum) : 1.0;
}

@@ -115,13 +127,13 @@ int64_t VectorMachineSupport::computeArchVectorLength(Type elementType) {
// =============================================================================

int64_t Z16VectorMachineSupport::computeArchVectorLength(
GenericOps Gop, Type elementType) {
GenericOps genOp, Type elementType) {
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
bool isFloat = mlir::isa<FloatType>(elementType);

// Support shared between int and float.
switch (Gop) {
switch (genOp) {
case GenericOps::ScalarOnlyGop:
return 1; // Must be scalar.
case GenericOps::SelectGop:
@@ -137,10 +149,10 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
// Supports only 32 and 64 bit Floats; There is support for extended too
// but ignore this for now.
if (!(bitWidth == 32 || bitWidth == 64 ||
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
return UNSUPPORTED;
// Now we have a supported length, test for specific operations.
switch (Gop) {
switch (genOp) {
case GenericOps::AbsGop: /* Supported via compare and select */
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::CeilGop: /* Use load integer & rounding modes*/
@@ -161,7 +173,7 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
}
}
// Support for integer (we consider bit-wide ops as byte wide ops).
switch (Gop) {
switch (genOp) {
// 1 - 16 byte operations.
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::ConversionGop:
@@ -190,13 +202,14 @@ int64_t Z16VectorMachineSupport::computeArchVectorLength(
// =============================================================================

int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
GenericOps Gop, Type elementType) {
GenericOps genOp, Type elementType) {
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
bool isFloat = mlir::isa<FloatType>(elementType);

// Support shared between int and float.
switch (Gop) {
switch (genOp) {
case GenericOps::ScalarOnlyGop:
return 1; // Must be scalar.
case GenericOps::SelectGop:
@@ -212,10 +225,10 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
// Supports only 32 and 64 bit Floats; There is support for extended too
// but ignore this for now.
if (!(bitWidth == 32 || bitWidth == 64 ||
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
return UNSUPPORTED;
// Now we have a supported length, test for specific operations.
switch (Gop) {
switch (genOp) {
case GenericOps::AbsGop:
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::CeilGop:
@@ -237,7 +250,7 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
}
}
// Support for integer (we consider bit-wide ops as byte wide ops).
switch (Gop) {
switch (genOp) {
// 1 - 16 byte operations.
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::ConversionGop:
@@ -276,13 +289,14 @@ int64_t SSE42x86VectorMachineSupport::computeArchVectorLength(
// =============================================================================

int64_t NeonVectorMachineSupport::computeArchVectorLength(
GenericOps Gop, Type elementType) {
GenericOps genOp, Type elementType) {
assert(genOp < GenericOps::LastGop && "no metrics here, only genOps");
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
int64_t archVL = VectorMachineSupport::getArchVectorLength(elementType);
bool isFloat = mlir::isa<FloatType>(elementType);

// Support shared between int and float.
switch (Gop) {
switch (genOp) {
case GenericOps::ScalarOnlyGop:
return 1; // Must be scalar.
case GenericOps::SelectGop:
@@ -297,10 +311,10 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
if (isFloat) {
// Supports only 32 and 64 bit Floats;
if (!(bitWidth == 32 || bitWidth == 64 ||
(bitWidth == 16 && Gop == GenericOps::ConversionGop)))
(bitWidth == 16 && genOp == GenericOps::ConversionGop)))
return UNSUPPORTED;
// Now we have a supported length, test for specific operations.
switch (Gop) {
switch (genOp) {
case GenericOps::AbsGop:
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::CeilGop:
@@ -322,7 +336,7 @@ int64_t NeonVectorMachineSupport::computeArchVectorLength(
}
}
// Support for integer (we consider bit-wide ops as byte wide ops).
switch (Gop) {
switch (genOp) {
// 1 - 16 byte operations.
case GenericOps::ArithmeticGop: /* Add/sub,... */
case GenericOps::ConversionGop:
@@ -370,10 +384,19 @@ GenOpMix computeGenOpMixUnion(const GenOpMix &mix1, const GenOpMix &mix2) {
for (auto pair : mix1) {
GenericOps genOp = pair.first;
int64_t num = pair.second;
if (u.find(genOp) != u.end())
u[genOp] += num; // Has this op already, add to it.
else
if (u.find(genOp) != u.end()) {
// Merge the 2 operation counts/metrics.
if (genOp == GenericOps::EstimatedVectorRegisterPressure) {
// For register pressure, pick the max of both.
u[genOp] = std::max(u[genOp], num);
} else {
// For operation counts, use the sum of both.
u[genOp] += num;
}
} else {
// First time we have this.
u[genOp] = num;
}
}
return u;
}
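As a concrete illustration of the union semantics above (operation counts are summed, while the register-pressure metric keeps the larger of the two estimates), here is a small hedged example; the operand values are made up for illustration.

// Illustrative only: merging two op mixes with computeGenOpMixUnion.
GenOpMix mixA = {{GenericOps::MulGop, 2},
    {GenericOps::EstimatedVectorRegisterPressure, 4}};
GenOpMix mixB = {{GenericOps::MulGop, 1}, {GenericOps::ArithmeticGop, 3},
    {GenericOps::EstimatedVectorRegisterPressure, 8}};
GenOpMix u = computeGenOpMixUnion(mixA, mixB);
// Expected result: MulGop = 3, ArithmeticGop = 3,
// EstimatedVectorRegisterPressure = max(4, 8) = 8.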