Optimization for Linear Quantization #2954

Merged · 12 commits · Sep 26, 2024
39 changes: 7 additions & 32 deletions src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp
@@ -71,15 +71,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
DimsExpr outputAF;
outputAF.emplace_back(zero);

// faster than original loop on z16, takes 124us for 64k vals
// Allocate output buffers.
MemRefType flatBufferType = llvm::cast<MemRefType>(flatInput.getType());
Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims);
DimsExpr bufferAF;
bufferAF.emplace_back(zero);

create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel,
{flatInput}, {inputAF}, {flatBuffer}, {bufferAF},
{flatInput}, {inputAF}, {flatAlloc}, {outputAF},
{[&](const KrnlBuilder &kb, ArrayRef<Value> inputVals, int64_t VL) {
MultiDialectBuilder<MathBuilder> create(kb);
Value x = inputVals[0];
@@ -95,29 +88,10 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter,
adjustX = roundX;
// Saturate: use max into a min.
Value saturateX = create.math.clip(adjustX, qMin, qMax);
// Old approach.
// return create.math.cast(quantizedElementType, saturateX);
return saturateX;
// Convert into quantized type.
return create.math.cast(quantizedElementType, saturateX);
}});

// A second loop that performs the scalar float-to-int conversion performs
// better than the compiler's attempt to generate SIMD conversion code. This
// might not hold for all data types, but is definitely noticeable with uint8.
//
// Investigate further: we might save the vector to a buffer on the fly
// (avoiding a second loop as below), then reload each value as a scalar and
// store it as a scalar (thus avoiding the insert/extract SIMD operations,
// which also do not perform well). We could keep a SIMD buffer in memory for
// the non-quantized and quantized SIMD values, but then we would also need to
// privatize it, which is not easy in this scheme. So ignore this for now.
create.krnl.forLoopIE(simdLb, simdUb, 1, enableParallel,
[&](const KrnlBuilder &kb, ValueRange loopInd) {
MultiDialectBuilder<KrnlBuilder, MemRefBuilder, MathBuilder> create(kb);
Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]});
Value res = create.math.cast(quantizedElementType, buffVal);
create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]});
});

Collaborator: I see, so now we don't need to use the additional loop for conversion.

if (totVL > 1)
onnxToKrnlSimdReport(op, /*successful*/ true, totVL,
simdLoopStaticTripCount, "quantizationLinear whole tensor");
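For context, the fused SIMD loop above now performs the scale, round, zero-point adjust, saturate, and narrowing cast in a single pass, without the intermediate buffer. A minimal scalar C++ sketch of the per-element computation, following the ONNX QuantizeLinear formula; the helper name and values are illustrative assumptions, not code from this PR:

```cpp
#include <cfenv>
#include <cmath>
#include <cstdint>
#include <cstdio>

// y = saturate(round(x / scale) + zero_point) for a uint8 target range.
static uint8_t quantizeOne(float x, float scale, uint8_t zeroPoint) {
  float scaled = x / scale;
  // std::nearbyint honors the current rounding mode; FE_TONEAREST rounds
  // ties to even, matching ONNX QuantizeLinear semantics.
  float rounded = std::nearbyint(scaled) + static_cast<float>(zeroPoint);
  // Saturate to the quantized range before the narrowing cast.
  float clipped = std::fmin(std::fmax(rounded, 0.0f), 255.0f);
  return static_cast<uint8_t>(clipped);
}

int main() {
  std::fesetround(FE_TONEAREST);
  std::printf("%u\n", static_cast<unsigned>(quantizeOne(0.6f, 0.1f, 128))); // 134
  return 0;
}
```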
@@ -202,9 +176,10 @@ struct ONNXQuantizeLinearOpLowering
hasZeroPoint = true;
}
if (disableQuantZeroPoint) {
// TODO: should we expect to disable hasZeroPoint forcefully, or generate
// an error if we had a zero point? Right now, just forcefully assert we
// have no zero point, i.e. ignore one even if we had a zero point.
// TODO: should we expect to disable hasZeroPoint forcefully, or
// generate an error if we had a zero point? Right now, just forcefully
// assert we have no zero point, i.e. ignore one even if we had a zero
// point.
hasZeroPoint = false;
}
emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType,
40 changes: 40 additions & 0 deletions src/Dialect/Mlir/DialectBuilder.cpp
@@ -886,6 +886,46 @@ Value MathBuilder::cast(Type destType, Value src) const {
LLVM_DEBUG(llvm::dbgs() << "srcType: " << srcType << "\n";
llvm::dbgs() << "destType: " << destType << "\n";);

// Before we proceed with the actual cast, there is a special case that we
// want to handle here. For a cast from float to an int of a different width,
// LLVM generates better patterns if we first cast from float to an int of the
// same width, and then from that int to an int of the target width.
// Skip that optimization if the result is 1 bit (boolean).
if (mlir::isa<FloatType>(srcElemType) &&
mlir::isa<IntegerType>(destElemType) && bitTrunc && destElemWidth > 1) {
// Quantization: float to smaller int. First determine the intermediary
// type, same integer type as destination type, with the same type width as
// the source float type.
Type step1ElementType;
IntegerType destIntType = mlir::cast<IntegerType>(destElemType);
bool destIssSigned = destIntType.isSignless() || destIntType.isSigned();
if (destIssSigned)
step1ElementType = b().getIntegerType(srcElemWidth);
else
step1ElementType = b().getIntegerType(srcElemWidth, false);
// Perform (recursively) the 2-step conversion. It is exceptionally OK to use
// the element type here, as cast will promote it to a vector if src is a vector.
Value step1Val = cast(step1ElementType, src);
return cast(destType, step1Val);
}
if (mlir::isa<IntegerType>(srcElemType) &&
mlir::isa<FloatType>(destElemType) && bitExtend) {
// Dequantization: small int to a float. First determine the intermediary
// type, same integer type as source type, with the same type width as
// the destination float type.
Type step1ElementType;
IntegerType srcIntType = mlir::cast<IntegerType>(srcElemType);
bool srcIssSigned = srcIntType.isSignless() || srcIntType.isSigned();
if (srcIssSigned)
step1ElementType = b().getIntegerType(destElemWidth);
else
step1ElementType = b().getIntegerType(destElemWidth, false);
// Perform (recursively) the 2-step conversion. It is exceptionally OK to use
// the element type here, as cast will promote it to a vector if src is a vector.
Value step1Val = cast(step1ElementType, src);
return cast(destType, step1Val);
}

// Handle booleans first because they need special handling.
// Boolean to int/float conversions. Booleans are unsigned.
if (srcElemType.isInteger(1)) {
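The two added branches split a narrowing float-to-int cast (and the widening int-to-float cast) into two steps through an integer of the source or destination float's width, since, per the comment above, LLVM generates better patterns that way. A hedged C++ sketch of the equivalent scalar semantics; the function names are illustrative assumptions, and the comments map each step to the arith ops visible in the lit tests below:

```cpp
#include <cstdint>
#include <cstdio>

// Quantization direction: f32 -> u8 done as f32 -> i32 -> u8, i.e.
// arith.fptoui (f32 to i32) followed by arith.trunci (i32 to i8), instead of
// a single arith.fptoui f32 -> i8. Assumes x is already saturated in range.
static uint8_t castF32ToU8TwoStep(float x) {
  uint32_t sameWidth = static_cast<uint32_t>(x); // fptoui f32 -> i32
  return static_cast<uint8_t>(sameWidth);        // trunci i32 -> i8
}

// Dequantization direction: i8 -> f32 done as i8 -> i32 -> f32, i.e.
// arith.extsi followed by arith.sitofp.
static float castI8ToF32TwoStep(int8_t x) {
  int32_t widened = static_cast<int32_t>(x);     // extsi i8 -> i32
  return static_cast<float>(widened);            // sitofp i32 -> f32
}

int main() {
  std::printf("%u %f\n", static_cast<unsigned>(castF32ToU8TwoStep(200.0f)),
              castI8ToF32TwoStep(-5));
  return 0;
}
```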
@@ -5,10 +5,12 @@

// -----


func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor<f32>, %arg2: tensor<i8>) -> tensor<4xf32> {
%0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xi8>, tensor<f32>, tensor<i8>) -> tensor<4xf32>
return %0 : tensor<4xf32>

// mlir2FileCheck.py
// CHECK-LABEL: func.func @test_dequantizelinear_i8
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xi8>, [[PARAM_1_:%.+]]: memref<f32>, [[PARAM_2_:%.+]]: memref<i8>) -> memref<4xf32> {
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32>
@@ -18,12 +20,13 @@ func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor<f32>, %ar
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xi8>
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref<f32>
// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref<i8>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_5_:%.+]] = arith.sitofp [[LOAD_PARAM_2_MEM_]] : i8 to f32
// CHECK-DAG: [[VAR_6_:%.+]] = arith.sitofp [[LOAD_PARAM_0_MEM_]] : i8 to f32
// CHECK: [[VAR_7_:%.+]] = arith.subf [[VAR_6_]], [[VAR_5_]] : f32
// CHECK: [[VAR_8_:%.+]] = arith.mulf [[VAR_7_]], [[LOAD_PARAM_1_MEM_]] : f32
// CHECK: krnl.store [[VAR_8_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32>
// CHECK: [[VAR_5_:%.+]] = arith.extsi [[LOAD_PARAM_0_MEM_]] : i8 to i32
// CHECK-DAG: [[VAR_6_:%.+]] = arith.sitofp [[VAR_5_]] : i32 to f32
// CHECK-DAG: [[VAR_7_:%.+]] = arith.extsi [[LOAD_PARAM_2_MEM_]] : i8 to i32
// CHECK: [[VAR_8_:%.+]] = arith.sitofp [[VAR_7_]] : i32 to f32
// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32
// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32
// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32>
// CHECK: }
// CHECK: return [[RES_]] : memref<4xf32>
// CHECK: }
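The updated CHECK lines reflect the two-step cast: the i8 input and zero point are first widened with arith.extsi and only then converted with arith.sitofp, instead of a direct i8-to-f32 sitofp. A small C++ equivalent of the per-element dequantization with the widening made explicit (the function name is an illustrative assumption):

```cpp
#include <cstdint>
#include <cstdio>

// y = (x - zero_point) * scale, with both i8 operands widened to i32 before
// the int-to-float conversion, mirroring extsi + sitofp in the CHECK lines.
static float dequantizeOne(int8_t x, float scale, int8_t zeroPoint) {
  float xf = static_cast<float>(static_cast<int32_t>(x));          // extsi + sitofp
  float zpf = static_cast<float>(static_cast<int32_t>(zeroPoint)); // extsi + sitofp
  return (xf - zpf) * scale;                                       // subf + mulf
}

int main() {
  std::printf("%f\n", dequantizeOne(-38, 0.5f, 2)); // (-38 - 2) * 0.5 = -20
  return 0;
}
```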
@@ -47,12 +50,14 @@ func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor<f32>, %
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref<f32>
// CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref<ui8>
// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8
// CHECK-DAG: [[VAR_6_:%.+]] = arith.uitofp [[VAR_5_]] : i8 to f32
// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8
// CHECK: [[VAR_8_:%.+]] = arith.uitofp [[VAR_7_]] : i8 to f32
// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32
// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32
// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32>
// CHECK: [[VAR_6_:%.+]] = arith.extui [[VAR_5_]] : i8 to i32
// CHECK-DAG: [[VAR_7_:%.+]] = arith.uitofp [[VAR_6_]] : i32 to f32
// CHECK-DAG: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8
// CHECK: [[VAR_9_:%.+]] = arith.extui [[VAR_8_]] : i8 to i32
// CHECK: [[VAR_10_:%.+]] = arith.uitofp [[VAR_9_]] : i32 to f32
// CHECK: [[VAR_11_:%.+]] = arith.subf [[VAR_7_]], [[VAR_10_]] : f32
// CHECK: [[VAR_12_:%.+]] = arith.mulf [[VAR_11_]], [[LOAD_PARAM_1_MEM_]] : f32
// CHECK: krnl.store [[VAR_12_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32>
// CHECK: }
// CHECK: return [[RES_]] : memref<4xf32>
// CHECK: }
@@ -75,19 +75,19 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor<?x2xf32>) -> (tensor<?x2xu
// CHECK-DAG: [[VAR_23_:%.+]] = arith.select [[VAR_21_]], [[VAR_22_]], [[VAR_12_]] : f32
// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_13_]], [[CST_5_dot_000000_]] : f32
// CHECK: [[VAR_25_:%.+]] = arith.select [[VAR_24_]], [[VAR_23_]], [[VAR_16_]] : f32
// CHECK: [[VAR_26_:%.+]] = arith.fptoui [[VAR_25_]] : f32 to i8
// CHECK: [[VAR_27_:%.+]] = builtin.unrealized_conversion_cast [[VAR_26_]] : i8 to ui8
// CHECK: [[VAR_26_:%.+]] = arith.fptoui [[VAR_25_]] : f32 to i32
// CHECK: [[VAR_27_:%.+]] = arith.trunci [[VAR_26_]] : i32 to i8
// CHECK: [[VAR_28_:%.+]] = builtin.unrealized_conversion_cast [[VAR_27_]] : i8 to ui8
// CHECK: krnl.store [[VAR_7_]], [[RES_1_]][] : memref<f32>
// CHECK: krnl.store [[VAR_27_]], [[RES_2_]][] : memref<ui8>
// CHECK-DAG: [[VAR_28_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK: krnl.store [[VAR_28_]], [[RES_2_]][] : memref<ui8>
// CHECK-DAG: [[VAR_29_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[VAR_28_]], [[RES_5_]][0] : memref<1xindex>
// CHECK: affine.store [[VAR_29_]], [[RES_5_]][0] : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_5_]]) : (memref<?x2xf32>, memref<1xindex>) -> memref<?xf32>
// CHECK-DAG: [[VAR_29_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[VAR_30_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}}
// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex>
// CHECK: affine.store [[VAR_29_]], [[RES_6_]][0] : memref<1xindex>
// CHECK: affine.store [[VAR_30_]], [[RES_6_]][0] : memref<1xindex>
// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref<?x2xui8>, memref<1xindex>) -> memref<?xui8>
// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc([[VAR_28_]]) {{.*}}: memref<?xf32>
// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){
// CHECK: [[VAR_32_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index
@@ -112,15 +112,10 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor<?x2xf32>) -> (tensor<?x2xu
// CHECK: [[VAR_49_:%.+]] = arith.addf [[VAR_48_]], [[VAR_25_]] : f32
// CHECK: [[VAR_50_:%.+]] = arith.maxnumf [[VAR_49_]], [[CST_0_dot_000000_]] : f32
// CHECK: [[VAR_51_:%.+]] = arith.minnumf [[VAR_50_]], [[CST_2_dot_550000_]] : f32
// CHECK: krnl.store [[VAR_51_]], [[RES_7_]]{{.}}[[VAR_32_2_]]{{.}} : memref<?xf32>
// CHECK: }
// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1
// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){
// CHECK: [[VAR_32_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_32_3_]]{{.}} : memref<?xf32>
// CHECK: [[LOAD_RES_3_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_1_]] : f32 to i8
// CHECK: [[VAR_35_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_RES_3_MEM_1_1_]] : i8 to ui8
// CHECK: krnl.store [[VAR_35_3_]], [[VAR_reshape_14_]]{{.}}[[VAR_32_3_]]{{.}} : memref<?xui8>
// CHECK: [[VAR_52_:%.+]] = arith.fptoui [[VAR_51_]] : f32 to i32
// CHECK: [[VAR_53_:%.+]] = arith.trunci [[VAR_52_]] : i32 to i8
// CHECK: [[VAR_54_:%.+]] = builtin.unrealized_conversion_cast [[VAR_53_]] : i8 to ui8
// CHECK: krnl.store [[VAR_54_]], [[VAR_reshape_14_]]{{.}}[[VAR_32_2_]]{{.}} : memref<?xui8>
// CHECK: }
// CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref<?x2xui8>, memref<f32>, memref<ui8>
// CHECK: }
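For readers unfamiliar with DynamicQuantizeLinear, the scale and zero point checked above are derived from the input's value range. A hedged C++ sketch following the ONNX operator definition for the uint8 range [0, 255]; names are illustrative assumptions, and the all-zero input case (scale == 0) is omitted for brevity:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Scale and zero point per the ONNX DynamicQuantizeLinear definition: the
// data range is extended to include 0 so that 0.0f is exactly representable.
static void dynamicQuantizeParams(const std::vector<float> &x, float &scale,
                                  uint8_t &zeroPoint) {
  float xMin = 0.0f, xMax = 0.0f;
  for (float v : x) {
    xMin = std::fmin(xMin, v);
    xMax = std::fmax(xMax, v);
  }
  scale = (xMax - xMin) / 255.0f;
  float zp = std::nearbyint(-xMin / scale);    // round, ties to even
  zp = std::fmin(std::fmax(zp, 0.0f), 255.0f); // saturate to [0, 255]
  zeroPoint = static_cast<uint8_t>(zp);
}

int main() {
  std::vector<float> x = {-1.0f, 0.5f, 3.0f};
  float scale;
  uint8_t zeroPoint;
  dynamicQuantizeParams(x, scale, zeroPoint);
  std::printf("scale=%f zero_point=%u\n", scale,
              static_cast<unsigned>(zeroPoint)); // scale~=0.015686, zp=64
  return 0;
}
```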