try 2 loops for quantize linear
Signed-off-by: Alexandre Eichenberger <alexe@us.ibm.com>
AlexandreEichenberger committed Sep 19, 2024
1 parent bfd6542 commit 60a8b9e
Showing 1 changed file with 38 additions and 2 deletions.
40 changes: 38 additions & 2 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
@@ -66,8 +66,6 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
     totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/,
         innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount,
         simdOnly);
-    MemRefType outputType = llvm::cast<MemRefType>(alloc.getType());
-    totVL = boostVLForMinUnroll(inputType, outputType, totVL);
   }
 
   IndexExpr zero = LitIE(0);
@@ -78,6 +76,43 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
   inputAF.emplace_back(zero);
   DimsExpr outputAF;
   outputAF.emplace_back(zero);
+
+#if 1
+  // Allocate an intermediate buffer with the input's (unquantized) element type.
+  MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
+  Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims);
+  DimsExpr bufferAF;
+  bufferAF.emplace_back(zero);
+
+  create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
+      {flatInput}, {inputAF}, {flatBuffer}, {bufferAF},
+      {[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
+        MultiDialectBuilder<MathBuilder> create(kb);
+        Value x = inputVals[0];
+        // Scale
+        Value scaleX = create.math.div(x, scale);
+        // Round
+        Value roundX = create.math.round(scaleX);
+        // Adjust
+        Value adjustX;
+        if (hasZeroPoint)
+          adjustX = create.math.add(roundX, zeroPoint);
+        else
+          adjustX = roundX;
+        // Saturate: clamp into [qMin, qMax] (a max feeding into a min).
+        Value saturateX = create.math.clip(adjustX, qMin, qMax);
+        return saturateX;
+      }});
+  create.krnl.forLoopIE(simdLb, simdUb, 1, /*parallel*/ false,
+      [&](KrnlBuilder &kb, ValueRange loopInd) {
+        MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(kb);
+        Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]});
+        Value res = create.math.cast(quantizedElementType, buffVal);
+        create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]});
+      });
+#else
+  MemRefType outputType = llvm::cast<MemRefType>(alloc.getType());
+  totVL = boostVLForMinUnroll(inputType, outputType, totVL);
   create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
       {flatInput}, {inputAF}, {flatAlloc}, {outputAF},
       {[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
@@ -98,6 +133,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
         Value res = create.math.cast(quantizedElementType, saturateX);
         return res;
       }});
+#endif
   if (totVL > 1)
     onnxToKrnlSimdReport(op, /*successful*/ true, totVL,
         simdLoopStaticTripCount, "quantizationLinear whole tensor");
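
For readers skimming the diff, here is a rough scalar C++ sketch of what the new two-loop lowering computes. It is only a model: the function name, the float/uint8_t types, and the use of std::nearbyint for rounding are illustrative assumptions; the commit itself emits Krnl/SIMD IR through the builders shown above, not C++.

// Hypothetical scalar model of the generated two-loop code (not the real Krnl IR).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void quantizeLinearTwoLoops(const std::vector<float> &input,
    std::vector<uint8_t> &output, float scale, float zeroPoint,
    bool hasZeroPoint, float qMin, float qMax) {
  const size_t n = input.size();
  output.resize(n);
  // Intermediate buffer keeps the unquantized element type, mirroring flatBuffer.
  std::vector<float> buffer(n);

  // Loop 1 (simdIterateIE above): scale, round, optionally add the zero point, saturate.
  for (size_t i = 0; i < n; ++i) {
    float x = input[i] / scale;                    // Scale
    x = std::nearbyint(x);                         // Round (rounding mode assumed)
    if (hasZeroPoint)                              // Adjust
      x += zeroPoint;
    buffer[i] = std::min(std::max(x, qMin), qMax); // Saturate: max into a min
  }

  // Loop 2 (forLoopIE above, scalar): cast to the quantized element type and store.
  for (size_t i = 0; i < n; ++i)
    output[i] = static_cast<uint8_t>(buffer[i]);
}

Presumably the point of the split is to let the floating-point scale/round/clip work vectorize on the unquantized type, while the narrowing cast to the quantized element type runs in a separate plain scalar loop; the #else branch keeps the previous single fused loop (including the boostVLForMinUnroll call) so the two variants can be compared.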