From 3978448655aa05b0039a4fd92c8965b95553e72c Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Wed, 25 Sep 2024 11:38:36 -0400 Subject: [PATCH 1/9] new attempt at linear quant opt Signed-off-by: Alexandre Eichenberger --- .../Quantization/QuantizeLinear.cpp | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index d7af519ea9..30248f0a44 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -71,6 +71,67 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, DimsExpr outputAF; outputAF.emplace_back(zero); +#if 1 + // hi alex: test with 2 loops for easier debugging + // Allocate output buffers. + MemRefType flatBufferType = llvm::cast(flatInput.getType()); + Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims); + DimsExpr bufferAF; + bufferAF.emplace_back(zero); + + create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, + {flatInput}, {inputAF}, {flatBuffer}, {bufferAF}, + {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { + MultiDialectBuilder create(kb); + Value x = inputVals[0]; + // Scale + Value scaleX = create.math.div(x, scale); + // Round + Value roundX = create.math.round(scaleX); + // Adjust + Value adjustX; + if (hasZeroPoint) + adjustX = create.math.add(roundX, zeroPoint); + else + adjustX = roundX; + // Saturate: use max into a min. + Value saturateX = create.math.clip(adjustX, qMin, qMax); + // Old approach. + // return create.math.cast(quantizedElementType, saturateX); + return saturateX; + }}); + + // Need transient types. + Type inputElementType = flatBufferType.getElementType(); + unsigned inputWidth; + if (isa(inputElementType)) + inputWidth = 32; + else if (isa(inputElementType)) + inputWidth = 64; + else + llvm_unreachable("unsupported input type"); + IntegerType quantizedIntType = cast(quantizedElementType); + // hi alex unsigned quantizedWidth = quantizedIntType.getWidth(); + bool isSignless = quantizedIntType.isSignless(); + bool isSigned = quantizedIntType.isSigned(); + Type quantizedElementTypeSameSizeAsInput = + rewriter.getIntegerType(inputWidth, isSignless || isSigned); + + create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, + {flatBuffer}, {bufferAF}, {flatAlloc}, {outputAF}, + {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { + MultiDialectBuilder create(kb); + // Convert float* to int*/uint* where * is 32 (64?) + Value input = inputVals[0]; + Value quantizedSameSizeAsInput = + create.math.cast(quantizedElementTypeSameSizeAsInput, input); + // Convert int32/uint32 to int*/unint* where * is 8, 16... + Value quantizedSameSizeAsOutput = + create.math.cast(quantizedElementType, quantizedSameSizeAsInput); + return quantizedSameSizeAsOutput; + }}); + +#else // faster than original loop on z16, takes 124us for 64k vals // Allocate output buffers. MemRefType flatBufferType = llvm::cast(flatInput.getType()); @@ -117,6 +178,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, Value res = create.math.cast(quantizedElementType, buffVal); create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]}); }); +#endif if (totVL > 1) onnxToKrnlSimdReport(op, /*successful*/ true, totVL, From ac67673756e42d3e065134a052f1e955dd15b8a7 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Wed, 25 Sep 2024 12:28:27 -0400 Subject: [PATCH 2/9] in steps Signed-off-by: Alexandre Eichenberger --- .../Quantization/QuantizeLinear.cpp | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 30248f0a44..ebbf4614bb 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -73,7 +73,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, #if 1 // hi alex: test with 2 loops for easier debugging - // Allocate output buffers. + // Allocate output buffers (same type as input). MemRefType flatBufferType = llvm::cast(flatInput.getType()); Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims); DimsExpr bufferAF; @@ -111,7 +111,6 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, else llvm_unreachable("unsupported input type"); IntegerType quantizedIntType = cast(quantizedElementType); - // hi alex unsigned quantizedWidth = quantizedIntType.getWidth(); bool isSignless = quantizedIntType.isSignless(); bool isSigned = quantizedIntType.isSigned(); Type quantizedElementTypeSameSizeAsInput = @@ -125,10 +124,23 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, Value input = inputVals[0]; Value quantizedSameSizeAsInput = create.math.cast(quantizedElementTypeSameSizeAsInput, input); - // Convert int32/uint32 to int*/unint* where * is 8, 16... - Value quantizedSameSizeAsOutput = + // Convert int32/uint32 to int*/unint* where * is 8, 16... +#if 0 + // Code get normalized to the code below + unsigned quantizedWidth = quantizedIntType.getWidth(); + unsigned currWidth = inputWidth; + Value qVal = quantizedSameSizeAsInput; + while (currWidth > quantizedWidth) { + currWidth = currWidth / 2; + Type qType = + rewriter.getIntegerType(currWidth, isSignless || isSigned); + qVal = create.math.cast(qType, qVal); + } +#else + Value qVal = create.math.cast(quantizedElementType, quantizedSameSizeAsInput); - return quantizedSameSizeAsOutput; +#endif + return qVal; }}); #else @@ -264,9 +276,10 @@ struct ONNXQuantizeLinearOpLowering hasZeroPoint = true; } if (disableQuantZeroPoint) { - // TODO: should we expect to disable hasZeroPoint forcefully, or generate - // an error if we had a zero point? Right now, just forcefully assert we - // have no zero point, i.e. ignore one even if we had a zero point. + // TODO: should we expect to disable hasZeroPoint forcefully, or + // generate an error if we had a zero point? Right now, just forcefully + // assert we have no zero point, i.e. ignore one even if we had a zero + // point. hasZeroPoint = false; } emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, From d30d7c7eee4587fef7973420da53798d8aa690b7 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Wed, 25 Sep 2024 12:48:36 -0400 Subject: [PATCH 3/9] 2 steps into one loop Signed-off-by: Alexandre Eichenberger --- .../Quantization/QuantizeLinear.cpp | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index ebbf4614bb..4a80100915 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -72,6 +72,46 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, outputAF.emplace_back(zero); #if 1 + Type inputElementType = inputType.getElementType(); + unsigned inputWidth; + if (isa(inputElementType)) + inputWidth = 32; + else if (isa(inputElementType)) + inputWidth = 64; + else + llvm_unreachable("unsupported input type"); + IntegerType quantizedIntType = cast(quantizedElementType); + bool isSigned = quantizedIntType.isSignless() || quantizedIntType.isSigned(); + Type quantizedElementTypeInputSized = + rewriter.getIntegerType(inputWidth, isSigned); + + create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, + {flatInput}, {inputAF}, {flatAlloc}, {outputAF}, + {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { + MultiDialectBuilder create(kb); + Value x = inputVals[0]; + // Scale + Value scaleX = create.math.div(x, scale); + // Round + Value roundX = create.math.round(scaleX); + // Adjust + Value adjustX; + if (hasZeroPoint) + adjustX = create.math.add(roundX, zeroPoint); + else + adjustX = roundX; + // Saturate: use max into a min. + Value saturateX = create.math.clip(adjustX, qMin, qMax); + // Convert float* to int*/uint* where * is 64/32. + Value qSaturateXInputSized = + create.math.cast(quantizedElementTypeInputSized, saturateX); + // Reduce quantized precision. + Value res = + create.math.cast(quantizedElementType, qSaturateXInputSized); + return res; + }}); + +#elif 1 // hi alex: test with 2 loops for easier debugging // Allocate output buffers (same type as input). MemRefType flatBufferType = llvm::cast(flatInput.getType()); From 79a8875e0eb1a05cd4a469c98725923c3007130c Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Wed, 25 Sep 2024 13:54:19 -0400 Subject: [PATCH 4/9] removed alternative code versions Signed-off-by: Alexandre Eichenberger --- .../Quantization/QuantizeLinear.cpp | 122 ------------------ 1 file changed, 122 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 4a80100915..b28974f2c4 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -71,7 +71,6 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, DimsExpr outputAF; outputAF.emplace_back(zero); -#if 1 Type inputElementType = inputType.getElementType(); unsigned inputWidth; if (isa(inputElementType)) @@ -111,127 +110,6 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, return res; }}); -#elif 1 - // hi alex: test with 2 loops for easier debugging - // Allocate output buffers (same type as input). - MemRefType flatBufferType = llvm::cast(flatInput.getType()); - Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims); - DimsExpr bufferAF; - bufferAF.emplace_back(zero); - - create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, - {flatInput}, {inputAF}, {flatBuffer}, {bufferAF}, - {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { - MultiDialectBuilder create(kb); - Value x = inputVals[0]; - // Scale - Value scaleX = create.math.div(x, scale); - // Round - Value roundX = create.math.round(scaleX); - // Adjust - Value adjustX; - if (hasZeroPoint) - adjustX = create.math.add(roundX, zeroPoint); - else - adjustX = roundX; - // Saturate: use max into a min. - Value saturateX = create.math.clip(adjustX, qMin, qMax); - // Old approach. - // return create.math.cast(quantizedElementType, saturateX); - return saturateX; - }}); - - // Need transient types. - Type inputElementType = flatBufferType.getElementType(); - unsigned inputWidth; - if (isa(inputElementType)) - inputWidth = 32; - else if (isa(inputElementType)) - inputWidth = 64; - else - llvm_unreachable("unsupported input type"); - IntegerType quantizedIntType = cast(quantizedElementType); - bool isSignless = quantizedIntType.isSignless(); - bool isSigned = quantizedIntType.isSigned(); - Type quantizedElementTypeSameSizeAsInput = - rewriter.getIntegerType(inputWidth, isSignless || isSigned); - - create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, - {flatBuffer}, {bufferAF}, {flatAlloc}, {outputAF}, - {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { - MultiDialectBuilder create(kb); - // Convert float* to int*/uint* where * is 32 (64?) - Value input = inputVals[0]; - Value quantizedSameSizeAsInput = - create.math.cast(quantizedElementTypeSameSizeAsInput, input); - // Convert int32/uint32 to int*/unint* where * is 8, 16... -#if 0 - // Code get normalized to the code below - unsigned quantizedWidth = quantizedIntType.getWidth(); - unsigned currWidth = inputWidth; - Value qVal = quantizedSameSizeAsInput; - while (currWidth > quantizedWidth) { - currWidth = currWidth / 2; - Type qType = - rewriter.getIntegerType(currWidth, isSignless || isSigned); - qVal = create.math.cast(qType, qVal); - } -#else - Value qVal = - create.math.cast(quantizedElementType, quantizedSameSizeAsInput); -#endif - return qVal; - }}); - -#else - // faster than original loop on z16, takes 124us for 64k vals - // Allocate output buffers. - MemRefType flatBufferType = llvm::cast(flatInput.getType()); - Value flatBuffer = create.mem.alignedAlloc(flatBufferType, flatInputDims); - DimsExpr bufferAF; - bufferAF.emplace_back(zero); - - create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, - {flatInput}, {inputAF}, {flatBuffer}, {bufferAF}, - {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { - MultiDialectBuilder create(kb); - Value x = inputVals[0]; - // Scale - Value scaleX = create.math.div(x, scale); - // Round - Value roundX = create.math.round(scaleX); - // Adjust - Value adjustX; - if (hasZeroPoint) - adjustX = create.math.add(roundX, zeroPoint); - else - adjustX = roundX; - // Saturate: use max into a min. - Value saturateX = create.math.clip(adjustX, qMin, qMax); - // Old approach. - // return create.math.cast(quantizedElementType, saturateX); - return saturateX; - }}); - - // A second loop that performs scalar float to int performs better than the - // compiler's attempt to generate SIMD conversion code. This might not hold - // with all data types, but is definitely noticeable with uint8. - // - // Investigate further: we might save the vector to a buffer on the fly - // (avoiding a second loop as below), and then reload each value as scalar and - // then saved them as scalar (thus avoiding the insert/extract SIMD operations - // that also do not perform well). We can have a SIMD buffer in memory for the - // non-quantized and quantized simd values, but then we also need to privatize - // it, which is also not easy in this scheme. So ignore this for now. - create.krnl.forLoopIE(simdLb, simdUb, 1, enableParallel, - [&](const KrnlBuilder &kb, ValueRange loopInd) { - MultiDialectBuilder create(kb); - Value buffVal = create.krnl.loadIE(flatBuffer, {zero}, {loopInd[0]}); - Value res = create.math.cast(quantizedElementType, buffVal); - create.krnl.storeIE(res, flatAlloc, {zero}, {loopInd[0]}); - }); -#endif - if (totVL > 1) onnxToKrnlSimdReport(op, /*successful*/ true, totVL, simdLoopStaticTripCount, "quantizationLinear whole tensor"); From 0dffbc9d95d1e3a99973e6aadd13b04a708bf7da Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Wed, 25 Sep 2024 15:48:12 -0400 Subject: [PATCH 5/9] larger unroll Signed-off-by: Alexandre Eichenberger --- src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index b28974f2c4..cd336a19ec 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -56,7 +56,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, {GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3}, {GenericOps::FloorGop, 2}, {GenericOps::EstimatedVectorRegisterPressure, - 8 /* Little parallelism in code. */}}; + 4 /* Little parallelism in code. */}}; totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/, innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount, simdOnly); From 34b97da853b296f43464077583ee1821fd70494d Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Wed, 25 Sep 2024 16:13:27 -0400 Subject: [PATCH 6/9] fix lit tests Signed-off-by: Alexandre Eichenberger --- .../Quantization/QuantizeLinear.cpp | 16 +- ...namicQuantizeLinear_with_canonicalize.mlir | 68 ++- ...QuantizeLinear_with_simd_canonicalize.mlir | 251 ++++++----- ...inear_with_simd_parallel_canonicalize.mlir | 398 +++++++++--------- .../QuantizationWithoutZeroPoint.mlir | 118 +++--- .../QuantizeLinear_with_canonicalize.mlir | 106 +++-- 6 files changed, 452 insertions(+), 505 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index cd336a19ec..65bba9deb0 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -81,9 +81,19 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, llvm_unreachable("unsupported input type"); IntegerType quantizedIntType = cast(quantizedElementType); bool isSigned = quantizedIntType.isSignless() || quantizedIntType.isSigned(); - Type quantizedElementTypeInputSized = - rewriter.getIntegerType(inputWidth, isSigned); - + Type quantizedElementTypeInputSized; + if (isSigned) { + // Cannot use getIntegerType(inputWidth, true) as it returns signed ints. + if (inputWidth == 64) + quantizedElementTypeInputSized = rewriter.getI64Type(); + else if (inputWidth == 32) + quantizedElementTypeInputSized = rewriter.getI32Type(); + else + llvm_unreachable("unsupported input type"); + } else { + // unsigned of the right type + quantizedElementTypeInputSized = rewriter.getIntegerType(inputWidth, false); + } create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, {flatInput}, {inputAF}, {flatAlloc}, {outputAF}, {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir index 934fba2240..022c7ece38 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir @@ -31,22 +31,22 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ -// CHECK: [[VAR_32_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_]]#0, [[VAR_32_]]#1] : memref +// CHECK: [[VAR_31_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_]]#0, [[VAR_31_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref -// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_35_]], [[RES_3_]][] : memref +// CHECK: [[VAR_34_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_34_]], [[RES_3_]][] : memref // CHECK: } // CHECK: [[RES_4_:%.+]] = memref.alloc() : memref // CHECK: krnl.memset [[RES_4_]], [[CST_0_]] : memref // CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 // CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref // CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ -// CHECK: [[VAR_32_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_1_]]#0, [[VAR_32_1_]]#1] : memref +// CHECK: [[VAR_31_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_1_]]#0, [[VAR_31_1_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref -// CHECK: [[VAR_35_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 -// CHECK: krnl.store [[VAR_35_1_]], [[RES_4_]][] : memref +// CHECK: [[VAR_34_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_34_1_]], [[RES_4_]][] : memref // CHECK: } // CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref @@ -87,40 +87,34 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK: affine.store [[VAR_29_]], [[RES_6_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc([[VAR_28_]]) {{.*}}: memref // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ -// CHECK: [[VAR_32_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_32_2_]]{{.}} : memref +// CHECK: [[VAR_31_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_31_2_]]{{.}} : memref // CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32 -// CHECK: [[VAR_35_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 -// CHECK: [[VAR_36_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_35_2_]] : f32 -// CHECK-DAG: [[VAR_37_:%.+]] = arith.cmpf ogt, [[VAR_36_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_34_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 +// CHECK: [[VAR_35_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_34_2_]] : f32 +// CHECK-DAG: [[VAR_36_:%.+]] = arith.cmpf ogt, [[VAR_35_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_39_:%.+]] = arith.select [[VAR_37_]], [[VAR_38_]], [[VAR_35_2_]] : f32 -// CHECK-DAG: [[VAR_40_:%.+]] = arith.mulf [[VAR_35_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_41_:%.+]] = math.floor [[VAR_40_]] : f32 -// CHECK: [[VAR_42_:%.+]] = arith.mulf [[VAR_41_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_43_:%.+]] = arith.subf [[VAR_35_2_]], [[VAR_42_]] : f32 -// CHECK-DAG: [[VAR_44_:%.+]] = arith.cmpf oeq, [[VAR_43_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_45_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_38_:%.+]] = arith.select [[VAR_36_]], [[VAR_37_]], [[VAR_34_2_]] : f32 +// CHECK-DAG: [[VAR_39_:%.+]] = arith.mulf [[VAR_34_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_40_:%.+]] = math.floor [[VAR_39_]] : f32 +// CHECK: [[VAR_41_:%.+]] = arith.mulf [[VAR_40_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_42_:%.+]] = arith.subf [[VAR_34_2_]], [[VAR_41_]] : f32 +// CHECK-DAG: [[VAR_43_:%.+]] = arith.cmpf oeq, [[VAR_42_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_44_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_46_:%.+]] = arith.select [[VAR_44_]], [[VAR_45_]], [[VAR_35_2_]] : f32 -// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_36_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_48_:%.+]] = arith.select [[VAR_47_]], [[VAR_46_]], [[VAR_39_]] : f32 -// CHECK: [[VAR_49_:%.+]] = arith.addf [[VAR_48_]], [[VAR_25_]] : f32 -// CHECK: [[VAR_50_:%.+]] = arith.maxnumf [[VAR_49_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_51_:%.+]] = arith.minnumf [[VAR_50_]], [[CST_2_dot_550000_]] : f32 -// CHECK: krnl.store [[VAR_51_]], [[RES_7_]]{{.}}[[VAR_32_2_]]{{.}} : memref -// CHECK: } -// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ -// CHECK: [[VAR_32_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_32_3_]]{{.}} : memref -// CHECK: [[LOAD_RES_3_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_1_]] : f32 to i8 -// CHECK: [[VAR_35_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_RES_3_MEM_1_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_35_3_]], [[VAR_reshape_14_]]{{.}}[[VAR_32_3_]]{{.}} : memref +// CHECK-DAG: [[VAR_45_:%.+]] = arith.select [[VAR_43_]], [[VAR_44_]], [[VAR_34_2_]] : f32 +// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_35_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_47_:%.+]] = arith.select [[VAR_46_]], [[VAR_45_]], [[VAR_38_]] : f32 +// CHECK: [[VAR_48_:%.+]] = arith.addf [[VAR_47_]], [[VAR_25_]] : f32 +// CHECK: [[VAR_49_:%.+]] = arith.maxnumf [[VAR_48_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_50_:%.+]] = arith.minnumf [[VAR_49_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_51_:%.+]] = arith.fptoui [[VAR_50_]] : f32 to i32 +// CHECK: [[VAR_52_:%.+]] = arith.trunci [[VAR_51_]] : i32 to i8 +// CHECK: [[VAR_53_:%.+]] = builtin.unrealized_conversion_cast [[VAR_52_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_53_]], [[VAR_reshape_14_]]{{.}}[[VAR_31_2_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref, memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir index 637cd5fdaf..25f2a50499 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir @@ -42,16 +42,16 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<4096xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<4096xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> @@ -97,44 +97,38 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_9_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> -// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<4096xf32> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xf32>, vector<16xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<16xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<16xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> -// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<16xf32> -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<16xf32> +// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<16xf32> -// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<16xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<16xf32> +// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<16xf32> +// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<16xf32> -// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<16xf32> -// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: vector.store [[VAR_54_]], [[RES_10_]]{{.}}[[VAR_33_1_]]{{.}} : memref<4096xf32>, vector<16xf32> -// CHECK: } -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_33_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[RES_10_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xf32> -// CHECK: [[LOAD_VAR_reshape_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_VAR_reshape_MEM_2_]] : f32 to i8 -// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_VAR_reshape_MEM_1_1_]] : i8 to ui8 -// CHECK: krnl.store [[LOAD_RES_4_MEM_2_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xui8> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<16xf32> +// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<16xf32> +// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_55_:%.+]] = arith.trunci [[VAR_54_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_56_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_56_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xui8>, vector<16xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<256x16xui8>, memref, memref // CHECK: } @@ -179,29 +173,29 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4304){ -// CHECK: [[VAR_35_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_35_]]{{.}} : memref<4335xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_35_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK: [[VAR_34_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_34_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_34_]]{{.}} : memref<4335xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_40_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_41_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_40_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: } // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 4320 to 4335){ -// CHECK: [[VAR_35_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_35_1_]]{{.}} : memref<4335xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_35_1_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_34_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_34_1_]]{{.}} : memref<4335xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_34_1_]]{{.}} : memref<4335xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: krnl.store [[VAR_40_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> -// CHECK: krnl.store [[VAR_41_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: krnl.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK: krnl.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> @@ -247,70 +241,67 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_9_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> -// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<4335xf32> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 4320){ -// CHECK: [[VAR_35_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xf32>, vector<16xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> // CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<16xf32> // CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = math.floor [[LOAD_RES_4_MEM_1_]] : vector<16xf32> -// CHECK: [[VAR_40_2_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_]], [[LOAD_RES_6_MEM_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.cmpf ogt, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_42_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_]], [[LOAD_RES_6_MEM_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_43_:%.+]] = arith.select [[VAR_41_2_]], [[VAR_42_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_44_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK: [[VAR_45_:%.+]] = math.floor [[VAR_44_]] : vector<16xf32> -// CHECK: [[VAR_46_:%.+]] = arith.mulf [[VAR_45_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_47_:%.+]] = arith.subf [[LOAD_RES_6_MEM_1_]], [[VAR_46_]] : vector<16xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_47_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_43_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_44_:%.+]] = math.floor [[VAR_43_]] : vector<16xf32> +// CHECK: [[VAR_45_:%.+]] = arith.mulf [[VAR_44_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_46_:%.+]] = arith.subf [[LOAD_RES_6_MEM_1_]], [[VAR_45_]] : vector<16xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_46_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_48_]], [[VAR_49_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_51_:%.+]] = arith.cmpf oeq, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_47_]], [[VAR_48_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_52_:%.+]] = arith.select [[VAR_51_]], [[VAR_50_]], [[VAR_43_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_53_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.addf [[VAR_52_]], [[VAR_53_]] : vector<16xf32> -// CHECK: [[VAR_55_:%.+]] = arith.maxnumf [[VAR_54_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_56_:%.+]] = arith.minnumf [[VAR_55_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: vector.store [[VAR_56_]], [[RES_10_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = arith.select [[VAR_50_]], [[VAR_49_]], [[VAR_42_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_52_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_]], [[VAR_52_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_57_:%.+]] = arith.trunci [[VAR_56_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_58_:%.+]] = builtin.unrealized_conversion_cast [[VAR_57_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_58_]], [[VAR_reshape_21_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xui8>, vector<16xui8> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 4320 to 4335){ -// CHECK: [[VAR_35_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[VAR_reshape_19_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_34_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[VAR_reshape_19_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xf32> // CHECK: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_11_]] : f32 // CHECK: [[LOAD_RES_4_MEM_1_1_:%.+]] = math.floor [[LOAD_VAR_reshape_MEM_3_1_]] : f32 // CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.subf [[LOAD_VAR_reshape_MEM_3_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.cmpf ogt, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_41_3_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_39_3_:%.+]] = arith.cmpf ogt, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_3_]], [[VAR_41_3_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_44_1_:%.+]] = math.floor [[VAR_43_1_]] : f32 -// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[VAR_44_1_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_1_]], [[VAR_45_1_]] : f32 -// CHECK-DAG: [[VAR_47_1_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_48_1_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.select [[VAR_39_3_]], [[VAR_40_3_]], [[LOAD_RES_4_MEM_1_1_]] : f32 +// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_43_1_:%.+]] = math.floor [[VAR_42_1_]] : f32 +// CHECK: [[VAR_44_1_:%.+]] = arith.mulf [[VAR_43_1_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_45_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_1_]], [[VAR_44_1_]] : f32 +// CHECK-DAG: [[VAR_46_1_:%.+]] = arith.cmpf oeq, [[VAR_45_1_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_47_1_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_1_:%.+]] = arith.select [[VAR_47_1_]], [[VAR_48_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_50_1_:%.+]] = arith.cmpf oeq, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_51_1_:%.+]] = arith.select [[VAR_50_1_]], [[VAR_49_1_]], [[VAR_42_1_]] : f32 -// CHECK: [[VAR_52_1_:%.+]] = arith.addf [[VAR_51_1_]], [[VAR_29_]] : f32 -// CHECK: [[VAR_53_1_:%.+]] = arith.maxnumf [[VAR_52_1_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_54_1_:%.+]] = arith.minnumf [[VAR_53_1_]], [[CST_2_dot_550000_]] : f32 -// CHECK: krnl.store [[VAR_54_1_]], [[RES_10_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> -// CHECK: } -// CHECK: [[LOOP_4_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_4_]]) with ([[LOOP_4_]] -> [[I_4_:%.+]] = 0 to 4335){ -// CHECK: [[VAR_35_4_:%.+]] = krnl.get_induction_var_value([[LOOP_4_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[RES_10_]]{{.}}[[VAR_35_4_]]{{.}} : memref<4335xf32> -// CHECK: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.fptoui [[LOAD_VAR_reshape_MEM_2_1_]] : f32 to i8 -// CHECK: [[LOAD_RES_4_MEM_1_1_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_VAR_reshape_MEM_3_1_]] : i8 to ui8 -// CHECK: krnl.store [[LOAD_RES_4_MEM_1_1_]], [[VAR_reshape_21_]]{{.}}[[VAR_35_4_]]{{.}} : memref<4335xui8> +// CHECK-DAG: [[VAR_48_1_:%.+]] = arith.select [[VAR_46_1_]], [[VAR_47_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 +// CHECK-DAG: [[VAR_49_1_:%.+]] = arith.cmpf oeq, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_50_1_:%.+]] = arith.select [[VAR_49_1_]], [[VAR_48_1_]], [[VAR_41_1_]] : f32 +// CHECK: [[VAR_51_1_:%.+]] = arith.addf [[VAR_50_1_]], [[VAR_29_]] : f32 +// CHECK: [[VAR_52_1_:%.+]] = arith.maxnumf [[VAR_51_1_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_53_1_:%.+]] = arith.minnumf [[VAR_52_1_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_54_1_:%.+]] = arith.fptoui [[VAR_53_1_]] : f32 to i32 +// CHECK: [[VAR_55_1_:%.+]] = arith.trunci [[VAR_54_1_]] : i32 to i8 +// CHECK: [[VAR_56_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_56_1_]], [[VAR_reshape_21_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<255x17xui8>, memref, memref // CHECK: } @@ -355,16 +346,16 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> @@ -410,44 +401,38 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_9_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> -// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> +// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> +// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> +// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<8xf32> +// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: vector.store [[VAR_54_]], [[RES_10_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: } -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[RES_10_]]{{.}}[[VAR_33_2_]]{{.}} : memref<8xf32> -// CHECK: [[LOAD_VAR_reshape_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_VAR_reshape_MEM_2_]] : f32 to i8 -// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_VAR_reshape_MEM_1_1_]] : i8 to ui8 -// CHECK: krnl.store [[LOAD_RES_4_MEM_2_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_2_]]{{.}} : memref<8xui8> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> +// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> +// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi32> +// CHECK: [[VAR_55_:%.+]] = arith.trunci [[VAR_54_]] : vector<8xi32> to vector<8xi8> +// CHECK: [[VAR_56_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_56_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xui8>, vector<8xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<1x8xui8>, memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir index dd048e3ab7..f73ec71385 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir @@ -49,46 +49,46 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_34_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_:%.+]] = affine.apply [[MAP_0_]]([[VAR_34_]]) -// CHECK-DAG: [[VAR_36_:%.+]] = affine.min [[MAP_1_]]([[VAR_34_]]) -// CHECK-DAG: [[VAR_37_:%.+]] = affine.apply [[MAP_2_]]([[VAR_34_]]) -// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: [[VAR_38_:%.+]] = affine.min [[MAP_3_]]([[VAR_34_]]) -// CHECK: scf.for [[I_1_:%.+]] = [[VAR_35_]] to [[VAR_38_]] step [[CST_32_]] { +// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_34_:%.+]] = affine.apply [[MAP_0_]]([[VAR_33_]]) +// CHECK-DAG: [[VAR_35_:%.+]] = affine.min [[MAP_1_]]([[VAR_33_]]) +// CHECK-DAG: [[VAR_36_:%.+]] = affine.apply [[MAP_2_]]([[VAR_33_]]) +// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_37_:%.+]] = affine.min [[MAP_3_]]([[VAR_33_]]) +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_34_]] to [[VAR_37_]] step [[CST_32_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4096xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4096xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_52_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_51_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_52_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_50_]], [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_51_]], [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK: } -// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_4_]]([[VAR_34_]]) -// CHECK: [[VAR_40_:%.+]] = arith.remsi [[VAR_39_]], [[CST_32_]] : index -// CHECK: [[VAR_41_:%.+]] = arith.subi [[VAR_39_]], [[VAR_40_]] : index -// CHECK: [[VAR_42_:%.+]] = arith.addi [[VAR_35_]], [[VAR_41_]] : index -// CHECK: scf.for [[I_2_:%.+]] = [[VAR_42_]] to [[VAR_36_]] step [[CST_1_]] { +// CHECK: [[VAR_38_:%.+]] = affine.min [[MAP_4_]]([[VAR_33_]]) +// CHECK: [[VAR_39_:%.+]] = arith.remsi [[VAR_38_]], [[CST_32_]] : index +// CHECK: [[VAR_40_:%.+]] = arith.subi [[VAR_38_]], [[VAR_39_]] : index +// CHECK: [[VAR_41_:%.+]] = arith.addi [[VAR_34_]], [[VAR_40_]] : index +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_41_]] to [[VAR_35_]] step [[CST_1_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4096xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4096xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: memref.store [[VAR_51_1_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> -// CHECK: memref.store [[VAR_52_1_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_50_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_51_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: memref.store [[VAR_50_1_]], [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> +// CHECK: memref.store [[VAR_51_1_]], [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> // CHECK: } -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_45_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 -// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 -// CHECK: memref.store [[VAR_45_]], [[RES_5_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> -// CHECK: memref.store [[VAR_46_]], [[RES_7_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_44_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_45_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_44_]], [[RES_5_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32> +// CHECK: memref.store [[VAR_45_]], [[RES_7_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32> // CHECK: } // CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() : memref @@ -96,16 +96,16 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ -// CHECK: [[VAR_34_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> -// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> +// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_34_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_35_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_35_1_]] : f32 -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_36_1_]] : f32 -// CHECK: krnl.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> -// CHECK: krnl.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_34_1_]] : f32 +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_35_1_]] : f32 +// CHECK: krnl.store [[VAR_38_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: krnl.store [[VAR_39_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_4_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_4_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> @@ -151,46 +151,39 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_11_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_25_:%.+]] = memref.reshape [[RES_]]([[RES_]]_24) : (memref<256x16xui8>, memref<1xindex>) -> memref<4096xui8> -// CHECK-DAG: [[RES_12_:%.+]] = memref.alloc() {{.*}}: memref<4096xf32> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4096xf32>, vector<16xf32> -// CHECK-DAG: [[VAR_36_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_35_1_]], [[VAR_36_2_]] : vector<16xf32> -// CHECK: [[VAR_38_1_:%.+]] = math.floor [[VAR_37_1_]] : vector<16xf32> -// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_38_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_1_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_33_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_34_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_35_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_36_1_:%.+]] = arith.divf [[VAR_34_1_]], [[VAR_35_2_]] : vector<16xf32> +// CHECK: [[VAR_37_1_:%.+]] = math.floor [[VAR_36_1_]] : vector<16xf32> +// CHECK: [[VAR_38_2_:%.+]] = arith.subf [[VAR_36_1_]], [[VAR_37_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_39_2_:%.+]] = arith.cmpf ogt, [[VAR_38_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.addf [[VAR_37_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.select [[VAR_39_2_]], [[VAR_40_1_]], [[VAR_37_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_37_1_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> -// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_45_1_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_42_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_52_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> -// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_2_]], [[VAR_52_2_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: vector.store [[VAR_55_]], [[RES_12_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4096xf32>, vector<16xf32> -// CHECK: } -// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.parallel([[LOOP_3_]]) : !krnl.loop -// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_34_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[VAR_35_1_1_:%.+]] = krnl.load [[RES_12_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4096xf32> -// CHECK: [[VAR_36_3_:%.+]] = arith.fptoui [[VAR_35_1_1_]] : f32 to i8 -// CHECK: [[VAR_37_2_:%.+]] = builtin.unrealized_conversion_cast [[VAR_36_3_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_37_2_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4096xui8> +// CHECK: [[VAR_44_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_45_1_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_44_1_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_45_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_37_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_37_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_38_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_50_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_41_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_51_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_2_]], [[VAR_51_2_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_56_:%.+]] = arith.trunci [[VAR_55_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_25_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xui8>, vector<16xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_13, [[RES_]]_14 : memref<256x16xui8>, memref, memref // CHECK: } @@ -242,46 +235,46 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_35_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_36_:%.+]] = affine.apply [[MAP_0_]]([[VAR_35_]]) -// CHECK-DAG: [[VAR_37_:%.+]] = affine.min [[MAP_1_]]([[VAR_35_]]) -// CHECK-DAG: [[VAR_38_:%.+]] = affine.apply [[MAP_2_]]([[VAR_35_]]) -// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_3_]]([[VAR_35_]]) -// CHECK: scf.for [[I_1_:%.+]] = [[VAR_36_]] to [[VAR_39_]] step [[CST_32_]] { +// CHECK: [[VAR_34_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_35_:%.+]] = affine.apply [[MAP_0_]]([[VAR_34_]]) +// CHECK-DAG: [[VAR_36_:%.+]] = affine.min [[MAP_1_]]([[VAR_34_]]) +// CHECK-DAG: [[VAR_37_:%.+]] = affine.apply [[MAP_2_]]([[VAR_34_]]) +// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_38_:%.+]] = affine.min [[MAP_3_]]([[VAR_34_]]) +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_35_]] to [[VAR_38_]] step [[CST_32_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4335xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4335xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_52_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_53_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_52_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_53_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_52_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_51_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_52_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK: } -// CHECK: [[VAR_40_:%.+]] = affine.min [[MAP_4_]]([[VAR_35_]]) -// CHECK: [[VAR_41_:%.+]] = arith.remsi [[VAR_40_]], [[CST_32_]] : index -// CHECK: [[VAR_42_:%.+]] = arith.subi [[VAR_40_]], [[VAR_41_]] : index -// CHECK: [[VAR_43_:%.+]] = arith.addi [[VAR_36_]], [[VAR_42_]] : index -// CHECK: scf.for [[I_2_:%.+]] = [[VAR_43_]] to [[VAR_37_]] step [[CST_1_]] { +// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_4_]]([[VAR_34_]]) +// CHECK: [[VAR_40_:%.+]] = arith.remsi [[VAR_39_]], [[CST_32_]] : index +// CHECK: [[VAR_41_:%.+]] = arith.subi [[VAR_39_]], [[VAR_40_]] : index +// CHECK: [[VAR_42_:%.+]] = arith.addi [[VAR_35_]], [[VAR_41_]] : index +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_42_]] to [[VAR_36_]] step [[CST_1_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4335xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4335xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_53_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: memref.store [[VAR_52_1_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> -// CHECK: memref.store [[VAR_53_1_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_51_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: memref.store [[VAR_51_1_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK: memref.store [[VAR_52_1_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> // CHECK: } -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 -// CHECK-DAG: [[VAR_47_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 -// CHECK: memref.store [[VAR_46_]], [[RES_5_]]{{.}}[[VAR_35_]]{{.}} : memref<8xf32> -// CHECK: memref.store [[VAR_47_]], [[RES_7_]]{{.}}[[VAR_35_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_45_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_45_]], [[RES_5_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> +// CHECK: memref.store [[VAR_46_]], [[RES_7_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> // CHECK: } // CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() : memref @@ -289,16 +282,16 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ -// CHECK: [[VAR_35_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_35_1_]]{{.}} : memref<8xf32> -// CHECK-DAG: [[VAR_37_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_35_1_]]{{.}} : memref<8xf32> +// CHECK: [[VAR_34_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_35_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_36_1_]] : f32 -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_37_1_]] : f32 -// CHECK: krnl.store [[VAR_40_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> -// CHECK: krnl.store [[VAR_41_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_35_1_]] : f32 +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_36_1_]] : f32 +// CHECK: krnl.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: krnl.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_4_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_4_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> @@ -344,72 +337,68 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_11_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_25_:%.+]] = memref.reshape [[RES_]]([[RES_]]_24) : (memref<255x17xui8>, memref<1xindex>) -> memref<4335xui8> -// CHECK-DAG: [[RES_12_:%.+]] = memref.alloc() {{.*}}: memref<4335xf32> // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4320){ -// CHECK: [[VAR_35_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_36_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> -// CHECK-DAG: [[VAR_37_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> -// CHECK: [[VAR_38_1_:%.+]] = arith.divf [[VAR_36_1_]], [[VAR_37_2_]] : vector<16xf32> -// CHECK: [[VAR_39_1_:%.+]] = math.floor [[VAR_38_1_]] : vector<16xf32> -// CHECK: [[VAR_40_2_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_39_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.cmpf ogt, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.addf [[VAR_39_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.select [[VAR_41_2_]], [[VAR_42_1_]], [[VAR_39_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_39_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_36_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_35_1_]], [[VAR_36_2_]] : vector<16xf32> +// CHECK: [[VAR_38_1_:%.+]] = math.floor [[VAR_37_1_]] : vector<16xf32> +// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_38_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_1_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> -// CHECK: [[VAR_46_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_47_1_:%.+]] = arith.subf [[VAR_39_1_]], [[VAR_46_1_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_47_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_39_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_39_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_52_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_43_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_53_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.addf [[VAR_52_2_]], [[VAR_53_2_]] : vector<16xf32> -// CHECK: [[VAR_55_:%.+]] = arith.maxnumf [[VAR_54_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_56_:%.+]] = arith.minnumf [[VAR_55_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: vector.store [[VAR_56_]], [[RES_12_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_45_1_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_51_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_42_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_52_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_2_]], [[VAR_52_2_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_57_:%.+]] = arith.trunci [[VAR_56_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_58_:%.+]] = builtin.unrealized_conversion_cast [[VAR_57_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_58_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xui8>, vector<16xui8> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 4320 to 4335){ -// CHECK: [[VAR_35_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[VAR_36_1_1_:%.+]] = krnl.load [[VAR_reshape_23_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> -// CHECK: [[VAR_37_3_:%.+]] = arith.divf [[VAR_36_1_1_]], [[VAR_11_]] : f32 -// CHECK: [[VAR_38_2_:%.+]] = math.floor [[VAR_37_3_]] : f32 -// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_3_]], [[VAR_38_2_]] : f32 -// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_41_3_:%.+]] = arith.addf [[VAR_38_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_2_:%.+]] = arith.select [[VAR_40_3_]], [[VAR_41_3_]], [[VAR_38_2_]] : f32 -// CHECK-DAG: [[VAR_43_2_:%.+]] = arith.mulf [[VAR_38_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = math.floor [[VAR_43_2_]] : f32 +// CHECK: [[VAR_34_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[VAR_35_1_1_:%.+]] = krnl.load [[VAR_reshape_23_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_36_3_:%.+]] = arith.divf [[VAR_35_1_1_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_37_2_:%.+]] = math.floor [[VAR_36_3_]] : f32 +// CHECK: [[VAR_38_2_:%.+]] = arith.subf [[VAR_36_3_]], [[VAR_37_2_]] : f32 +// CHECK-DAG: [[VAR_39_3_:%.+]] = arith.cmpf ogt, [[VAR_38_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.addf [[VAR_37_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.select [[VAR_39_3_]], [[VAR_40_3_]], [[VAR_37_2_]] : f32 +// CHECK-DAG: [[VAR_42_2_:%.+]] = arith.mulf [[VAR_37_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = math.floor [[VAR_42_2_]] : f32 // CHECK: [[LOAD_RES_6_MEM_2_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_2_1_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_46_2_:%.+]] = arith.subf [[VAR_38_2_]], [[LOAD_RES_6_MEM_2_1_]] : f32 -// CHECK-DAG: [[VAR_47_2_:%.+]] = arith.cmpf oeq, [[VAR_46_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = arith.addf [[VAR_38_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.select [[VAR_47_2_]], [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_38_2_]] : f32 -// CHECK-DAG: [[LOAD_RES_4_MEM_1_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.select [[LOAD_RES_4_MEM_1_1_]], [[LOAD_VAR_reshape_MEM_3_1_]], [[VAR_42_2_]] : f32 -// CHECK: [[VAR_52_3_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_1_]], [[VAR_29_]] : f32 -// CHECK: [[VAR_53_3_:%.+]] = arith.maxnumf [[VAR_52_3_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_54_1_:%.+]] = arith.minnumf [[VAR_53_3_]], [[CST_2_dot_550000_]] : f32 -// CHECK: krnl.store [[VAR_54_1_]], [[RES_12_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> -// CHECK: } -// CHECK: [[LOOP_4_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.parallel([[LOOP_4_]]) : !krnl.loop -// CHECK: krnl.iterate([[LOOP_4_]]) with ([[LOOP_4_]] -> [[I_6_:%.+]] = 0 to 4335){ -// CHECK: [[VAR_35_4_:%.+]] = krnl.get_induction_var_value([[LOOP_4_]]) : (!krnl.loop) -> index -// CHECK: [[VAR_36_1_1_:%.+]] = krnl.load [[RES_12_]]{{.}}[[VAR_35_4_]]{{.}} : memref<4335xf32> -// CHECK: [[VAR_37_4_:%.+]] = arith.fptoui [[VAR_36_1_1_]] : f32 to i8 -// CHECK: [[VAR_38_3_:%.+]] = builtin.unrealized_conversion_cast [[VAR_37_4_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_38_3_]], [[VAR_reshape_25_]]{{.}}[[VAR_35_4_]]{{.}} : memref<4335xui8> +// CHECK: [[VAR_45_2_:%.+]] = arith.subf [[VAR_37_2_]], [[LOAD_RES_6_MEM_2_1_]] : f32 +// CHECK-DAG: [[VAR_46_2_:%.+]] = arith.cmpf oeq, [[VAR_45_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = arith.addf [[VAR_37_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.select [[VAR_46_2_]], [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_37_2_]] : f32 +// CHECK-DAG: [[LOAD_RES_4_MEM_1_1_:%.+]] = arith.cmpf oeq, [[VAR_38_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.select [[LOAD_RES_4_MEM_1_1_]], [[LOAD_VAR_reshape_MEM_3_1_]], [[VAR_41_2_]] : f32 +// CHECK: [[VAR_51_3_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_1_]], [[VAR_29_]] : f32 +// CHECK: [[VAR_52_3_:%.+]] = arith.maxnumf [[VAR_51_3_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_53_1_:%.+]] = arith.minnumf [[VAR_52_3_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_54_1_:%.+]] = arith.fptoui [[VAR_53_1_]] : f32 to i32 +// CHECK: [[VAR_55_1_:%.+]] = arith.trunci [[VAR_54_1_]] : i32 to i8 +// CHECK: [[VAR_56_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_56_1_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_13, [[RES_]]_14 : memref<255x17xui8>, memref, memref // CHECK: } @@ -454,16 +443,16 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> @@ -509,46 +498,39 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_9_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_]]([[RES_]]_20) : (memref<1x8xui8>, memref<1xindex>) -> memref<8xui8> -// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<8xf32> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__1_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: vector.store [[VAR_54_]], [[RES_10_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: } -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.parallel([[LOOP_2_]]) : !krnl.loop -// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[RES_10_]]{{.}}[[VAR_33_2_]]{{.}} : memref<8xf32> -// CHECK: [[LOAD_VAR_reshape_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_VAR_reshape_MEM_2_]] : f32 to i8 -// CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_VAR_reshape_MEM_1_1_]] : i8 to ui8 -// CHECK: krnl.store [[LOAD_RES_4_MEM_2_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_2_]]{{.}} : memref<8xui8> +// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> +// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> +// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> +// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<8xf32> +// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> +// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> +// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi32> +// CHECK: [[VAR_55_:%.+]] = arith.trunci [[VAR_54_]] : vector<8xi32> to vector<8xi8> +// CHECK: [[VAR_56_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_56_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xui8>, vector<8xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<1x8xui8>, memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir index ff079672bd..91b68f1adb 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir @@ -60,22 +60,22 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ -// CHECK: [[VAR_13_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_13_]]#0, [[VAR_13_]]#1] : memref +// CHECK: [[VAR_12_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_12_]]#0, [[VAR_12_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref -// CHECK: [[VAR_16_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_16_]], [[RES_3_]][] : memref +// CHECK: [[VAR_15_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_15_]], [[RES_3_]][] : memref // CHECK: } // CHECK: [[RES_4_:%.+]] = memref.alloc() : memref // CHECK: krnl.memset [[RES_4_]], [[CST_0_1_]] : memref // CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 // CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref // CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ -// CHECK: [[VAR_13_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_13_1_]]#0, [[VAR_13_1_]]#1] : memref +// CHECK: [[VAR_12_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_12_1_]]#0, [[VAR_12_1_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref -// CHECK: [[VAR_16_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 -// CHECK: krnl.store [[VAR_16_1_]], [[RES_4_]][] : memref +// CHECK: [[VAR_15_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_15_1_]], [[RES_4_]][] : memref // CHECK: } // CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref @@ -95,39 +95,33 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK: affine.store [[VAR_10_]], [[RES_6_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc([[VAR_9_]]) {{.*}}: memref // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ -// CHECK: [[VAR_13_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_13_2_]]{{.}} : memref +// CHECK: [[VAR_12_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_12_2_]]{{.}} : memref // CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32 -// CHECK: [[VAR_16_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 -// CHECK: [[VAR_17_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_16_2_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf ogt, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_16_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_15_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_16_2_]] : f32 -// CHECK-DAG: [[VAR_21_:%.+]] = arith.mulf [[VAR_16_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_22_:%.+]] = math.floor [[VAR_21_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.mulf [[VAR_22_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.subf [[VAR_16_2_]], [[VAR_23_]] : f32 -// CHECK-DAG: [[VAR_25_:%.+]] = arith.cmpf oeq, [[VAR_24_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_26_:%.+]] = arith.addf [[VAR_16_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_2_]], [[VAR_22_]] : f32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_27_:%.+]] = arith.select [[VAR_25_]], [[VAR_26_]], [[VAR_16_2_]] : f32 -// CHECK-DAG: [[VAR_28_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.select [[VAR_28_]], [[VAR_27_]], [[VAR_20_]] : f32 -// CHECK: [[VAR_30_:%.+]] = arith.maxnumf [[VAR_29_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_31_:%.+]] = arith.minnumf [[VAR_30_]], [[CST_2_dot_550000_]] : f32 -// CHECK: krnl.store [[VAR_31_]], [[RES_7_]]{{.}}[[VAR_13_2_]]{{.}} : memref -// CHECK: } -// CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ -// CHECK: [[VAR_13_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_13_3_]]{{.}} : memref -// CHECK: [[LOAD_RES_3_MEM_1_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_1_]] : f32 to i8 -// CHECK: [[VAR_16_3_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_RES_3_MEM_1_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_16_3_]], [[VAR_reshape_14_]]{{.}}[[VAR_13_3_]]{{.}} : memref +// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 +// CHECK: [[VAR_29_:%.+]] = arith.maxnumf [[VAR_28_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_30_:%.+]] = arith.minnumf [[VAR_29_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_31_:%.+]] = arith.fptoui [[VAR_30_]] : f32 to i32 +// CHECK: [[VAR_32_:%.+]] = arith.trunci [[VAR_31_]] : i32 to i8 +// CHECK: [[VAR_33_:%.+]] = builtin.unrealized_conversion_cast [[VAR_32_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_33_]], [[VAR_reshape_14_]]{{.}}[[VAR_12_2_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref, memref, memref // CHECK: } @@ -150,39 +144,33 @@ func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %a // CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xui8> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<6xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ -// CHECK: [[VAR_3_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_3_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_5_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: [[VAR_6_:%.+]] = math.floor [[VAR_5_]] : f32 -// CHECK: [[VAR_7_:%.+]] = arith.subf [[VAR_5_]], [[VAR_6_]] : f32 -// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpf ogt, [[VAR_7_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_9_:%.+]] = arith.addf [[VAR_6_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_2_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_4_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_5_:%.+]] = math.floor [[VAR_4_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpf ogt, [[VAR_6_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.addf [[VAR_5_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_10_:%.+]] = arith.select [[VAR_8_]], [[VAR_9_]], [[VAR_6_]] : f32 -// CHECK-DAG: [[VAR_11_:%.+]] = arith.mulf [[VAR_6_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_12_:%.+]] = math.floor [[VAR_11_]] : f32 -// CHECK: [[VAR_13_:%.+]] = arith.mulf [[VAR_12_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_14_:%.+]] = arith.subf [[VAR_6_]], [[VAR_13_]] : f32 -// CHECK-DAG: [[VAR_15_:%.+]] = arith.cmpf oeq, [[VAR_14_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.addf [[VAR_6_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_7_]], [[VAR_8_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.mulf [[VAR_5_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = math.floor [[VAR_10_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.mulf [[VAR_11_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.subf [[VAR_5_]], [[VAR_12_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.cmpf oeq, [[VAR_13_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addf [[VAR_5_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_:%.+]] = arith.select [[VAR_15_]], [[VAR_16_]], [[VAR_6_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf oeq, [[VAR_7_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_19_:%.+]] = arith.select [[VAR_18_]], [[VAR_17_]], [[VAR_10_]] : f32 -// CHECK: [[VAR_20_:%.+]] = arith.maxnumf [[VAR_19_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = arith.minnumf [[VAR_20_]], [[CST_2_dot_550000_]] : f32 -// CHECK: krnl.store [[VAR_21_]], [[RES_1_]]{{.}}[[VAR_3_]]{{.}} : memref<6xf32> -// CHECK: } -// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 6){ -// CHECK: [[VAR_3_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_3_1_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_5_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_]] : f32 to i8 -// CHECK: [[VAR_6_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_5_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_6_1_]], [[RES_]]{{.}}[[VAR_3_1_]]{{.}} : memref<6xui8> +// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[VAR_15_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_6_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[VAR_16_]], [[VAR_9_]] : f32 +// CHECK: [[VAR_19_:%.+]] = arith.maxnumf [[VAR_18_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_20_:%.+]] = arith.minnumf [[VAR_19_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = arith.fptoui [[VAR_20_]] : f32 to i32 +// CHECK: [[VAR_22_:%.+]] = arith.trunci [[VAR_21_]] : i32 to i8 +// CHECK: [[VAR_23_:%.+]] = builtin.unrealized_conversion_cast [[VAR_22_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_23_]], [[RES_]]{{.}}[[VAR_2_]]{{.}} : memref<6xui8> // CHECK: } // CHECK: return [[RES_]] : memref<6xui8> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir index 6ec8f9d8cb..8717a0f795 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir @@ -23,40 +23,34 @@ func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %a // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref // CHECK: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 // CHECK-DAG: [[VAR_3_:%.+]] = arith.uitofp [[VAR_2_]] : i8 to f32 -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<6xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ -// CHECK: [[VAR_6_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_6_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_8_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: [[VAR_9_:%.+]] = math.floor [[VAR_8_]] : f32 -// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_8_]], [[VAR_9_]] : f32 -// CHECK-DAG: [[VAR_11_:%.+]] = arith.cmpf ogt, [[VAR_10_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_12_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_5_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_7_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_8_:%.+]] = math.floor [[VAR_7_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.cmpf ogt, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_13_:%.+]] = arith.select [[VAR_11_]], [[VAR_12_]], [[VAR_9_]] : f32 -// CHECK-DAG: [[VAR_14_:%.+]] = arith.mulf [[VAR_9_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 -// CHECK: [[VAR_16_:%.+]] = arith.mulf [[VAR_15_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_17_:%.+]] = arith.subf [[VAR_9_]], [[VAR_16_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_10_]], [[VAR_11_]], [[VAR_8_]] : f32 +// CHECK-DAG: [[VAR_13_:%.+]] = arith.mulf [[VAR_8_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_14_:%.+]] = math.floor [[VAR_13_]] : f32 +// CHECK: [[VAR_15_:%.+]] = arith.mulf [[VAR_14_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_8_]], [[VAR_15_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_9_]] : f32 -// CHECK-DAG: [[VAR_21_:%.+]] = arith.cmpf oeq, [[VAR_10_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.select [[VAR_21_]], [[VAR_20_]], [[VAR_13_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.addf [[VAR_22_]], [[VAR_3_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.maxnumf [[VAR_23_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_25_:%.+]] = arith.minnumf [[VAR_24_]], [[CST_2_dot_550000_]] : f32 -// CHECK: krnl.store [[VAR_25_]], [[RES_1_]]{{.}}[[VAR_6_]]{{.}} : memref<6xf32> -// CHECK: } -// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 6){ -// CHECK: [[VAR_6_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_6_1_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_8_1_:%.+]] = arith.fptoui [[LOAD_PARAM_0_MEM_1_]] : f32 to i8 -// CHECK: [[VAR_9_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_8_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_9_1_]], [[RES_]]{{.}}[[VAR_6_1_]]{{.}} : memref<6xui8> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_8_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.cmpf oeq, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = arith.select [[VAR_20_]], [[VAR_19_]], [[VAR_12_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_3_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[VAR_22_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.minnumf [[VAR_23_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.fptoui [[VAR_24_]] : f32 to i32 +// CHECK: [[VAR_26_:%.+]] = arith.trunci [[VAR_25_]] : i32 to i8 +// CHECK: [[VAR_27_:%.+]] = builtin.unrealized_conversion_cast [[VAR_26_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_27_]], [[RES_]]{{.}}[[VAR_5_]]{{.}} : memref<6xui8> // CHECK: } // CHECK: return [[RES_]] : memref<6xui8> // CHECK: } @@ -82,39 +76,33 @@ func.func @test_quantize_linear_i8(%arg0: tensor<6xf32>, %arg1: tensor, %ar // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref // CHECK-NOT: separator of consecutive DAGs // CHECK-DAG: [[VAR_2_:%.+]] = arith.sitofp [[LOAD_PARAM_2_MEM_]] : i8 to f32 -// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<6xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ -// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_5_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_7_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: [[VAR_8_:%.+]] = math.floor [[VAR_7_]] : f32 -// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 -// CHECK-DAG: [[VAR_10_:%.+]] = arith.cmpf ogt, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_11_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_4_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_6_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_7_:%.+]] = math.floor [[VAR_6_]] : f32 +// CHECK: [[VAR_8_:%.+]] = arith.subf [[VAR_6_]], [[VAR_7_]] : f32 +// CHECK-DAG: [[VAR_9_:%.+]] = arith.cmpf ogt, [[VAR_8_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.addf [[VAR_7_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_10_]], [[VAR_11_]], [[VAR_8_]] : f32 -// CHECK-DAG: [[VAR_13_:%.+]] = arith.mulf [[VAR_8_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_14_:%.+]] = math.floor [[VAR_13_]] : f32 -// CHECK: [[VAR_15_:%.+]] = arith.mulf [[VAR_14_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_8_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_9_]], [[VAR_10_]], [[VAR_7_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.mulf [[VAR_7_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_13_:%.+]] = math.floor [[VAR_12_]] : f32 +// CHECK: [[VAR_14_:%.+]] = arith.mulf [[VAR_13_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = arith.subf [[VAR_7_]], [[VAR_14_]] : f32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.cmpf oeq, [[VAR_15_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.addf [[VAR_7_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_8_]] : f32 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.cmpf oeq, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = arith.select [[VAR_20_]], [[VAR_19_]], [[VAR_12_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_2_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[VAR_22_]], [[CST_minus_1_dot_280000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.minnumf [[VAR_23_]], [[CST_1_dot_270000_]] : f32 -// CHECK: krnl.store [[VAR_24_]], [[RES_1_]]{{.}}[[VAR_5_]]{{.}} : memref<6xf32> -// CHECK: } -// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 6){ -// CHECK: [[VAR_5_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[RES_1_]]{{.}}[[VAR_5_1_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_7_1_:%.+]] = arith.fptosi [[LOAD_PARAM_0_MEM_1_]] : f32 to i8 -// CHECK: krnl.store [[VAR_7_1_]], [[RES_]]{{.}}[[VAR_5_1_]]{{.}} : memref<6xi8> +// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_16_]], [[VAR_17_]], [[VAR_7_]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.cmpf oeq, [[VAR_8_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_20_:%.+]] = arith.select [[VAR_19_]], [[VAR_18_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_21_:%.+]] = arith.addf [[VAR_20_]], [[VAR_2_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.maxnumf [[VAR_21_]], [[CST_minus_1_dot_280000_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.minnumf [[VAR_22_]], [[CST_1_dot_270000_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.fptosi [[VAR_23_]] : f32 to i32 +// CHECK: [[VAR_25_:%.+]] = arith.trunci [[VAR_24_]] : i32 to i8 +// CHECK: krnl.store [[VAR_25_]], [[RES_]]{{.}}[[VAR_4_]]{{.}} : memref<6xi8> // CHECK: } // CHECK: return [[RES_]] : memref<6xi8> // CHECK: } From 0b270e5c96242a32b6f9f722f9e0e87c0155f93b Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Wed, 25 Sep 2024 16:13:52 -0400 Subject: [PATCH 7/9] limit unroll Signed-off-by: Alexandre Eichenberger --- src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 65bba9deb0..463532817e 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -56,7 +56,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, {GenericOps::MulGop, 2}, {GenericOps::SelectGop, 3}, {GenericOps::FloorGop, 2}, {GenericOps::EstimatedVectorRegisterPressure, - 4 /* Little parallelism in code. */}}; + 8 /* Little parallelism in code. */}}; totVL = computeSuitableUnrollFactor(inputType /* use unquantized type*/, innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount, simdOnly); From 2fdd72c49d2db682ad963d895c4ff993eb176f10 Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Thu, 26 Sep 2024 14:34:19 -0400 Subject: [PATCH 8/9] moved conversions to math builder cast Signed-off-by: Alexandre Eichenberger --- .../Quantization/QuantizeLinear.cpp | 32 +- src/Dialect/Mlir/DialectBuilder.cpp | 40 ++ ...namicQuantizeLinear_with_canonicalize.mlir | 77 ++-- ...QuantizeLinear_with_simd_canonicalize.mlir | 257 +++++------ ...inear_with_simd_parallel_canonicalize.mlir | 401 +++++++++--------- .../onnx_to_krnl/onnx_lowering_fuse.mlir | 19 +- utils/fixLitTest.py | 1 + 7 files changed, 425 insertions(+), 402 deletions(-) diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 463532817e..01293c81e9 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -71,29 +71,6 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, DimsExpr outputAF; outputAF.emplace_back(zero); - Type inputElementType = inputType.getElementType(); - unsigned inputWidth; - if (isa(inputElementType)) - inputWidth = 32; - else if (isa(inputElementType)) - inputWidth = 64; - else - llvm_unreachable("unsupported input type"); - IntegerType quantizedIntType = cast(quantizedElementType); - bool isSigned = quantizedIntType.isSignless() || quantizedIntType.isSigned(); - Type quantizedElementTypeInputSized; - if (isSigned) { - // Cannot use getIntegerType(inputWidth, true) as it returns signed ints. - if (inputWidth == 64) - quantizedElementTypeInputSized = rewriter.getI64Type(); - else if (inputWidth == 32) - quantizedElementTypeInputSized = rewriter.getI32Type(); - else - llvm_unreachable("unsupported input type"); - } else { - // unsigned of the right type - quantizedElementTypeInputSized = rewriter.getIntegerType(inputWidth, false); - } create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, {flatInput}, {inputAF}, {flatAlloc}, {outputAF}, {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { @@ -111,13 +88,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, adjustX = roundX; // Saturate: use max into a min. Value saturateX = create.math.clip(adjustX, qMin, qMax); - // Convert float* to int*/uint* where * is 64/32. - Value qSaturateXInputSized = - create.math.cast(quantizedElementTypeInputSized, saturateX); - // Reduce quantized precision. - Value res = - create.math.cast(quantizedElementType, qSaturateXInputSized); - return res; + // Convert into quantized type. + return create.math.cast(quantizedElementType, saturateX); }}); if (totVL > 1) diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index 1075a25203..ff425513a5 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -886,6 +886,46 @@ Value MathBuilder::cast(Type destType, Value src) const { LLVM_DEBUG(llvm::dbgs() << "srcType: " << srcType << "\n"; llvm::dbgs() << "destType: " << destType << "\n";); + // Before we process with the actual cast, there is a special case that we + // want to handle here. Cast from float to int that have different width, llvm + // generate better patterns if we first cast from float to int of the same + // width, and then from int to a different size int. + // Skip that optimization if the result is a 1 bit (boolean). + if (mlir::isa(srcElemType) && + mlir::isa(destElemType) && bitTrunc && destElemWidth > 1) { + // Quantization: float to smaller int. First determine the intermediary + // type, same integer type as destination type, with the same type width as + // the source float type. + Type step1ElementType; + IntegerType destIntType = mlir::cast(destElemType); + bool destIssSigned = destIntType.isSignless() || destIntType.isSigned(); + if (destIssSigned) + step1ElementType = b().getIntegerType(srcElemWidth); + else + step1ElementType = b().getIntegerType(srcElemWidth, false); + // Perform (recursively) the 2 step conversion. Exceptionally ok here to use + // element type here as cast will promote it to a vector if src is a vector. + Value step1Val = cast(step1ElementType, src); + return cast(destType, step1Val); + } + if (mlir::isa(srcElemType) && + mlir::isa(destElemType) && bitExtend) { + // Dequantization: small int to a float. First determine the intermediary + // type, same integer type as source type, with the same type width as + // the destination float type. + Type step1ElementType; + IntegerType srcIntType = mlir::cast(srcElemType); + bool srcIssSigned = srcIntType.isSignless() || srcIntType.isSigned(); + if (srcIssSigned) + step1ElementType = b().getIntegerType(destElemWidth); + else + step1ElementType = b().getIntegerType(destElemWidth, false); + // Perform (recursively) the 2 step conversion. Exceptionally ok here to use + // element type here as cast will promote it to a vector if src is a vector. + Value step1Val = cast(step1ElementType, src); + return cast(destType, step1Val); + } + // Handle boolean first because they need special handling. // Boolean to int/float conversions. Boolean are unsigned. if (srcElemType.isInteger(1)) { diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir index 022c7ece38..d8dd788672 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_canonicalize.mlir @@ -31,22 +31,22 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor // CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ -// CHECK: [[VAR_31_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_]]#0, [[VAR_31_]]#1] : memref +// CHECK: [[VAR_32_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_]]#0, [[VAR_32_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref -// CHECK: [[VAR_34_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_34_]], [[RES_3_]][] : memref +// CHECK: [[VAR_35_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_35_]], [[RES_3_]][] : memref // CHECK: } // CHECK: [[RES_4_:%.+]] = memref.alloc() : memref // CHECK: krnl.memset [[RES_4_]], [[CST_0_]] : memref // CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 // CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_2_]] : memref // CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ -// CHECK: [[VAR_31_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_31_1_]]#0, [[VAR_31_1_]]#1] : memref +// CHECK: [[VAR_32_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_32_1_]]#0, [[VAR_32_1_]]#1] : memref // CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref -// CHECK: [[VAR_34_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 -// CHECK: krnl.store [[VAR_34_1_]], [[RES_4_]][] : memref +// CHECK: [[VAR_35_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_35_1_]], [[RES_4_]][] : memref // CHECK: } // CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref @@ -75,46 +75,47 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor -// CHECK: krnl.store [[VAR_27_]], [[RES_2_]][] : memref -// CHECK-DAG: [[VAR_28_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK: krnl.store [[VAR_28_]], [[RES_2_]][] : memref +// CHECK-DAG: [[VAR_29_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> -// CHECK: affine.store [[VAR_28_]], [[RES_5_]][0] : memref<1xindex> +// CHECK: affine.store [[VAR_29_]], [[RES_5_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_5_]]) : (memref, memref<1xindex>) -> memref -// CHECK-DAG: [[VAR_29_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[VAR_30_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} // CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> -// CHECK: affine.store [[VAR_29_]], [[RES_6_]][0] : memref<1xindex> +// CHECK: affine.store [[VAR_30_]], [[RES_6_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref, memref<1xindex>) -> memref // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ -// CHECK: [[VAR_31_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_31_2_]]{{.}} : memref +// CHECK: [[VAR_32_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_32_2_]]{{.}} : memref // CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32 -// CHECK: [[VAR_34_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 -// CHECK: [[VAR_35_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_34_2_]] : f32 -// CHECK-DAG: [[VAR_36_:%.+]] = arith.cmpf ogt, [[VAR_35_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_35_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 +// CHECK: [[VAR_36_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_35_2_]] : f32 +// CHECK-DAG: [[VAR_37_:%.+]] = arith.cmpf ogt, [[VAR_36_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_:%.+]] = arith.select [[VAR_36_]], [[VAR_37_]], [[VAR_34_2_]] : f32 -// CHECK-DAG: [[VAR_39_:%.+]] = arith.mulf [[VAR_34_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_40_:%.+]] = math.floor [[VAR_39_]] : f32 -// CHECK: [[VAR_41_:%.+]] = arith.mulf [[VAR_40_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_42_:%.+]] = arith.subf [[VAR_34_2_]], [[VAR_41_]] : f32 -// CHECK-DAG: [[VAR_43_:%.+]] = arith.cmpf oeq, [[VAR_42_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_44_:%.+]] = arith.addf [[VAR_34_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_39_:%.+]] = arith.select [[VAR_37_]], [[VAR_38_]], [[VAR_35_2_]] : f32 +// CHECK-DAG: [[VAR_40_:%.+]] = arith.mulf [[VAR_35_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_41_:%.+]] = math.floor [[VAR_40_]] : f32 +// CHECK: [[VAR_42_:%.+]] = arith.mulf [[VAR_41_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_43_:%.+]] = arith.subf [[VAR_35_2_]], [[VAR_42_]] : f32 +// CHECK-DAG: [[VAR_44_:%.+]] = arith.cmpf oeq, [[VAR_43_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_45_:%.+]] = arith.addf [[VAR_35_2_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_45_:%.+]] = arith.select [[VAR_43_]], [[VAR_44_]], [[VAR_34_2_]] : f32 -// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_35_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_47_:%.+]] = arith.select [[VAR_46_]], [[VAR_45_]], [[VAR_38_]] : f32 -// CHECK: [[VAR_48_:%.+]] = arith.addf [[VAR_47_]], [[VAR_25_]] : f32 -// CHECK: [[VAR_49_:%.+]] = arith.maxnumf [[VAR_48_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_50_:%.+]] = arith.minnumf [[VAR_49_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_51_:%.+]] = arith.fptoui [[VAR_50_]] : f32 to i32 -// CHECK: [[VAR_52_:%.+]] = arith.trunci [[VAR_51_]] : i32 to i8 -// CHECK: [[VAR_53_:%.+]] = builtin.unrealized_conversion_cast [[VAR_52_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_53_]], [[VAR_reshape_14_]]{{.}}[[VAR_31_2_]]{{.}} : memref +// CHECK-DAG: [[VAR_46_:%.+]] = arith.select [[VAR_44_]], [[VAR_45_]], [[VAR_35_2_]] : f32 +// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_36_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_48_:%.+]] = arith.select [[VAR_47_]], [[VAR_46_]], [[VAR_39_]] : f32 +// CHECK: [[VAR_49_:%.+]] = arith.addf [[VAR_48_]], [[VAR_25_]] : f32 +// CHECK: [[VAR_50_:%.+]] = arith.maxnumf [[VAR_49_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_51_:%.+]] = arith.minnumf [[VAR_50_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_52_:%.+]] = arith.fptoui [[VAR_51_]] : f32 to i32 +// CHECK: [[VAR_53_:%.+]] = arith.trunci [[VAR_52_]] : i32 to i8 +// CHECK: [[VAR_54_:%.+]] = builtin.unrealized_conversion_cast [[VAR_53_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_54_]], [[VAR_reshape_14_]]{{.}}[[VAR_32_2_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref, memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir index 25f2a50499..b0bea0a414 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_canonicalize.mlir @@ -42,16 +42,16 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<4096xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<4096xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<4096xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_37_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_38_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> @@ -87,10 +87,11 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_]] : f32 // CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 // CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i8 -// CHECK: [[VAR_30_:%.+]] = builtin.unrealized_conversion_cast [[VAR_29_]] : i8 to ui8 +// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i32 +// CHECK: [[VAR_30_:%.+]] = arith.trunci [[VAR_29_]] : i32 to i8 +// CHECK: [[VAR_31_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_]] : i8 to ui8 // CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_30_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_31_]], [[RES_2_]][] : memref // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_8_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> @@ -100,35 +101,35 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<4096xf32>, vector<16xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<16xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<16xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<16xf32> -// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<16xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<16xf32> -// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<16xf32> -// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<16xf32> +// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<16xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<16xf32> -// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<16xf32> -// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<16xf32> to vector<16xi32> -// CHECK: [[VAR_55_:%.+]] = arith.trunci [[VAR_54_]] : vector<16xi32> to vector<16xi8> -// CHECK: [[VAR_56_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_]] : vector<16xi8> to vector<16xui8> -// CHECK: vector.store [[VAR_56_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<4096xui8>, vector<16xui8> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<16xf32> +// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_56_:%.+]] = arith.trunci [[VAR_55_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_1_]]{{.}} : memref<4096xui8>, vector<16xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<256x16xui8>, memref, memref // CHECK: } @@ -173,29 +174,29 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4304){ -// CHECK: [[VAR_34_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_34_]]{{.}} : memref<4335xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_34_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK: [[VAR_35_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_35_]]{{.}} : memref<4335xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_35_]]{{.}} : memref<4335xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_39_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_40_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_39_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_40_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_40_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_41_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK: } // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 4320 to 4335){ -// CHECK: [[VAR_34_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_34_1_]]{{.}} : memref<4335xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_34_1_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_35_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_35_1_]]{{.}} : memref<4335xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_35_1_]]{{.}} : memref<4335xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: krnl.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> -// CHECK: krnl.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: krnl.store [[VAR_40_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> +// CHECK: krnl.store [[VAR_41_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<32xf32>, vector<32xf32> @@ -231,10 +232,11 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[VAR_27_:%.+]] = arith.select [[VAR_25_]], [[VAR_26_]], [[VAR_16_]] : f32 // CHECK-DAG: [[VAR_28_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 // CHECK: [[VAR_29_:%.+]] = arith.select [[VAR_28_]], [[VAR_27_]], [[VAR_20_]] : f32 -// CHECK: [[VAR_30_:%.+]] = arith.fptoui [[VAR_29_]] : f32 to i8 -// CHECK: [[VAR_31_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_]] : i8 to ui8 +// CHECK: [[VAR_30_:%.+]] = arith.fptoui [[VAR_29_]] : f32 to i32 +// CHECK: [[VAR_31_:%.+]] = arith.trunci [[VAR_30_]] : i32 to i8 +// CHECK: [[VAR_32_:%.+]] = builtin.unrealized_conversion_cast [[VAR_31_]] : i8 to ui8 // CHECK: krnl.store [[VAR_11_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_31_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_32_]], [[RES_2_]][] : memref // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_8_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> @@ -244,64 +246,64 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 4320){ -// CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK: [[VAR_35_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> // CHECK: [[LOAD_RES_4_MEM_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<16xf32> // CHECK: [[LOAD_RES_6_MEM_1_:%.+]] = math.floor [[LOAD_RES_4_MEM_1_]] : vector<16xf32> -// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_]], [[LOAD_RES_6_MEM_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK: [[VAR_40_2_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_]], [[LOAD_RES_6_MEM_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.cmpf ogt, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_43_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK: [[VAR_44_:%.+]] = math.floor [[VAR_43_]] : vector<16xf32> -// CHECK: [[VAR_45_:%.+]] = arith.mulf [[VAR_44_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_46_:%.+]] = arith.subf [[LOAD_RES_6_MEM_1_]], [[VAR_45_]] : vector<16xf32> -// CHECK-DAG: [[VAR_47_:%.+]] = arith.cmpf oeq, [[VAR_46_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_43_:%.+]] = arith.select [[VAR_41_2_]], [[VAR_42_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_44_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_45_:%.+]] = math.floor [[VAR_44_]] : vector<16xf32> +// CHECK: [[VAR_46_:%.+]] = arith.mulf [[VAR_45_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_47_:%.+]] = arith.subf [[LOAD_RES_6_MEM_1_]], [[VAR_46_]] : vector<16xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_47_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_]], [[VAR_cst_3_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_47_]], [[VAR_48_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_48_]], [[VAR_49_]], [[LOAD_RES_6_MEM_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = arith.cmpf oeq, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_:%.+]] = arith.select [[VAR_50_]], [[VAR_49_]], [[VAR_42_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_52_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> -// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_]], [[VAR_52_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<16xf32> to vector<16xi32> -// CHECK: [[VAR_57_:%.+]] = arith.trunci [[VAR_56_]] : vector<16xi32> to vector<16xi8> -// CHECK: [[VAR_58_:%.+]] = builtin.unrealized_conversion_cast [[VAR_57_]] : vector<16xi8> to vector<16xui8> -// CHECK: vector.store [[VAR_58_]], [[VAR_reshape_21_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xui8>, vector<16xui8> +// CHECK-DAG: [[VAR_52_:%.+]] = arith.select [[VAR_51_]], [[VAR_50_]], [[VAR_43_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_53_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.addf [[VAR_52_]], [[VAR_53_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.maxnumf [[VAR_54_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_56_:%.+]] = arith.minnumf [[VAR_55_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_57_:%.+]] = arith.fptoui [[VAR_56_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_58_:%.+]] = arith.trunci [[VAR_57_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_59_:%.+]] = builtin.unrealized_conversion_cast [[VAR_58_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_59_]], [[VAR_reshape_21_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xui8>, vector<16xui8> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 4320 to 4335){ -// CHECK: [[VAR_34_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[VAR_reshape_19_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_35_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = krnl.load [[VAR_reshape_19_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> // CHECK: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_11_]] : f32 // CHECK: [[LOAD_RES_4_MEM_1_1_:%.+]] = math.floor [[LOAD_VAR_reshape_MEM_3_1_]] : f32 // CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.subf [[LOAD_VAR_reshape_MEM_3_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_39_3_:%.+]] = arith.cmpf ogt, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.cmpf ogt, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_41_3_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.select [[VAR_39_3_]], [[VAR_40_3_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_43_1_:%.+]] = math.floor [[VAR_42_1_]] : f32 -// CHECK: [[VAR_44_1_:%.+]] = arith.mulf [[VAR_43_1_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_45_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_1_]], [[VAR_44_1_]] : f32 -// CHECK-DAG: [[VAR_46_1_:%.+]] = arith.cmpf oeq, [[VAR_45_1_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_47_1_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_3_]], [[VAR_41_3_]], [[LOAD_RES_4_MEM_1_1_]] : f32 +// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_44_1_:%.+]] = math.floor [[VAR_43_1_]] : f32 +// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[VAR_44_1_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_1_1_]], [[VAR_45_1_]] : f32 +// CHECK-DAG: [[VAR_47_1_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_48_1_:%.+]] = arith.addf [[LOAD_RES_4_MEM_1_1_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_48_1_:%.+]] = arith.select [[VAR_46_1_]], [[VAR_47_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 -// CHECK-DAG: [[VAR_49_1_:%.+]] = arith.cmpf oeq, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_50_1_:%.+]] = arith.select [[VAR_49_1_]], [[VAR_48_1_]], [[VAR_41_1_]] : f32 -// CHECK: [[VAR_51_1_:%.+]] = arith.addf [[VAR_50_1_]], [[VAR_29_]] : f32 -// CHECK: [[VAR_52_1_:%.+]] = arith.maxnumf [[VAR_51_1_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_53_1_:%.+]] = arith.minnumf [[VAR_52_1_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_54_1_:%.+]] = arith.fptoui [[VAR_53_1_]] : f32 to i32 -// CHECK: [[VAR_55_1_:%.+]] = arith.trunci [[VAR_54_1_]] : i32 to i8 -// CHECK: [[VAR_56_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_56_1_]], [[VAR_reshape_21_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xui8> +// CHECK-DAG: [[VAR_49_1_:%.+]] = arith.select [[VAR_47_1_]], [[VAR_48_1_]], [[LOAD_RES_4_MEM_1_1_]] : f32 +// CHECK-DAG: [[VAR_50_1_:%.+]] = arith.cmpf oeq, [[LOAD_RES_6_MEM_1_1_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_51_1_:%.+]] = arith.select [[VAR_50_1_]], [[VAR_49_1_]], [[VAR_42_1_]] : f32 +// CHECK: [[VAR_52_1_:%.+]] = arith.addf [[VAR_51_1_]], [[VAR_29_]] : f32 +// CHECK: [[VAR_53_1_:%.+]] = arith.maxnumf [[VAR_52_1_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_54_1_:%.+]] = arith.minnumf [[VAR_53_1_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_55_1_:%.+]] = arith.fptoui [[VAR_54_1_]] : f32 to i32 +// CHECK: [[VAR_56_1_:%.+]] = arith.trunci [[VAR_55_1_]] : i32 to i8 +// CHECK: [[VAR_57_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_57_1_]], [[VAR_reshape_21_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<255x17xui8>, memref, memref // CHECK: } @@ -346,16 +348,16 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_37_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> @@ -391,10 +393,11 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_]] : f32 // CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 // CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i8 -// CHECK: [[VAR_30_:%.+]] = builtin.unrealized_conversion_cast [[VAR_29_]] : i8 to ui8 +// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i32 +// CHECK: [[VAR_30_:%.+]] = arith.trunci [[VAR_29_]] : i32 to i8 +// CHECK: [[VAR_31_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_]] : i8 to ui8 // CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_30_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_31_]], [[RES_2_]][] : memref // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_8_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> @@ -404,35 +407,35 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-DAG: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ -// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<8xf32> -// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<8xf32> +// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<8xf32> +// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi32> -// CHECK: [[VAR_55_:%.+]] = arith.trunci [[VAR_54_]] : vector<8xi32> to vector<8xi8> -// CHECK: [[VAR_56_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_56_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xui8>, vector<8xui8> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> +// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<8xf32> +// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<8xf32> to vector<8xi32> +// CHECK: [[VAR_56_:%.+]] = arith.trunci [[VAR_55_]] : vector<8xi32> to vector<8xi8> +// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xui8>, vector<8xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<1x8xui8>, memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir index f73ec71385..14d809207c 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DynamicQuantizeLinear_with_simd_parallel_canonicalize.mlir @@ -49,46 +49,46 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_34_:%.+]] = affine.apply [[MAP_0_]]([[VAR_33_]]) -// CHECK-DAG: [[VAR_35_:%.+]] = affine.min [[MAP_1_]]([[VAR_33_]]) -// CHECK-DAG: [[VAR_36_:%.+]] = affine.apply [[MAP_2_]]([[VAR_33_]]) -// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: [[VAR_37_:%.+]] = affine.min [[MAP_3_]]([[VAR_33_]]) -// CHECK: scf.for [[I_1_:%.+]] = [[VAR_34_]] to [[VAR_37_]] step [[CST_32_]] { +// CHECK: [[VAR_34_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_35_:%.+]] = affine.apply [[MAP_0_]]([[VAR_34_]]) +// CHECK-DAG: [[VAR_36_:%.+]] = affine.min [[MAP_1_]]([[VAR_34_]]) +// CHECK-DAG: [[VAR_37_:%.+]] = affine.apply [[MAP_2_]]([[VAR_34_]]) +// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_38_:%.+]] = affine.min [[MAP_3_]]([[VAR_34_]]) +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_35_]] to [[VAR_38_]] step [[CST_32_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4096xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4096xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_51_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_50_]], [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_51_]], [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_52_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_51_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_52_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK: } -// CHECK: [[VAR_38_:%.+]] = affine.min [[MAP_4_]]([[VAR_33_]]) -// CHECK: [[VAR_39_:%.+]] = arith.remsi [[VAR_38_]], [[CST_32_]] : index -// CHECK: [[VAR_40_:%.+]] = arith.subi [[VAR_38_]], [[VAR_39_]] : index -// CHECK: [[VAR_41_:%.+]] = arith.addi [[VAR_34_]], [[VAR_40_]] : index -// CHECK: scf.for [[I_2_:%.+]] = [[VAR_41_]] to [[VAR_35_]] step [[CST_1_]] { +// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_4_]]([[VAR_34_]]) +// CHECK: [[VAR_40_:%.+]] = arith.remsi [[VAR_39_]], [[CST_32_]] : index +// CHECK: [[VAR_41_:%.+]] = arith.subi [[VAR_39_]], [[VAR_40_]] : index +// CHECK: [[VAR_42_:%.+]] = arith.addi [[VAR_35_]], [[VAR_41_]] : index +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_42_]] to [[VAR_36_]] step [[CST_1_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4096xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4096xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_51_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: memref.store [[VAR_50_1_]], [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> -// CHECK: memref.store [[VAR_51_1_]], [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_51_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: memref.store [[VAR_51_1_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK: memref.store [[VAR_52_1_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> // CHECK: } -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_36_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_44_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 -// CHECK-DAG: [[VAR_45_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 -// CHECK: memref.store [[VAR_44_]], [[RES_5_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32> -// CHECK: memref.store [[VAR_45_]], [[RES_7_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_45_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_45_]], [[RES_5_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> +// CHECK: memref.store [[VAR_46_]], [[RES_7_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> // CHECK: } // CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() : memref @@ -96,16 +96,16 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ -// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_34_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32> -// CHECK-DAG: [[VAR_35_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32> +// CHECK: [[VAR_34_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_35_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_34_1_]] : f32 -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_35_1_]] : f32 -// CHECK: krnl.store [[VAR_38_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> -// CHECK: krnl.store [[VAR_39_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_35_1_]] : f32 +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_36_1_]] : f32 +// CHECK: krnl.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: krnl.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_4_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_4_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> @@ -141,10 +141,11 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK-DAG: [[VAR_27_:%.+]] = arith.select [[VAR_25_]], [[VAR_26_]], [[VAR_16_]] : f32 // CHECK-DAG: [[VAR_28_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 // CHECK: [[VAR_29_:%.+]] = arith.select [[VAR_28_]], [[VAR_27_]], [[VAR_20_]] : f32 -// CHECK: [[VAR_30_:%.+]] = arith.fptoui [[VAR_29_]] : f32 to i8 -// CHECK: [[VAR_31_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_]] : i8 to ui8 +// CHECK: [[VAR_30_:%.+]] = arith.fptoui [[VAR_29_]] : f32 to i32 +// CHECK: [[VAR_31_:%.+]] = arith.trunci [[VAR_30_]] : i32 to i8 +// CHECK: [[VAR_32_:%.+]] = builtin.unrealized_conversion_cast [[VAR_31_]] : i8 to ui8 // CHECK: krnl.store [[VAR_11_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_31_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_32_]], [[RES_2_]][] : memref // CHECK: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4096_]], [[RES_10_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_23_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<256x16xf32>, memref<1xindex>) -> memref<4096xf32> @@ -155,35 +156,35 @@ func.func @test_dynamic_quantize_linear_simd_only(%arg0: tensor<256x16xf32>) -> // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4096){ -// CHECK: [[VAR_33_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_34_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xf32>, vector<16xf32> -// CHECK-DAG: [[VAR_35_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> -// CHECK: [[VAR_36_1_:%.+]] = arith.divf [[VAR_34_1_]], [[VAR_35_2_]] : vector<16xf32> -// CHECK: [[VAR_37_1_:%.+]] = math.floor [[VAR_36_1_]] : vector<16xf32> -// CHECK: [[VAR_38_2_:%.+]] = arith.subf [[VAR_36_1_]], [[VAR_37_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_39_2_:%.+]] = arith.cmpf ogt, [[VAR_38_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.addf [[VAR_37_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.select [[VAR_39_2_]], [[VAR_40_1_]], [[VAR_37_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_37_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4096xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_36_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_35_1_]], [[VAR_36_2_]] : vector<16xf32> +// CHECK: [[VAR_38_1_:%.+]] = math.floor [[VAR_37_1_]] : vector<16xf32> +// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_38_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_1_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> -// CHECK: [[VAR_44_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_45_1_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_44_1_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_45_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_37_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_37_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_38_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_50_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_41_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_51_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> -// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_2_]], [[VAR_51_2_]] : vector<16xf32> -// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<16xf32> to vector<16xi32> -// CHECK: [[VAR_56_:%.+]] = arith.trunci [[VAR_55_]] : vector<16xi32> to vector<16xi8> -// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<16xi8> to vector<16xui8> -// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_25_]]{{.}}[[VAR_33_2_]]{{.}} : memref<4096xui8>, vector<16xui8> +// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_45_1_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_51_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_42_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_52_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_2_]], [[VAR_52_2_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_57_:%.+]] = arith.trunci [[VAR_56_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_58_:%.+]] = builtin.unrealized_conversion_cast [[VAR_57_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_58_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4096xui8>, vector<16xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_13, [[RES_]]_14 : memref<256x16xui8>, memref, memref // CHECK: } @@ -235,46 +236,46 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.parallel([[LOOP_0_]]) : !krnl.loop // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_34_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_:%.+]] = affine.apply [[MAP_0_]]([[VAR_34_]]) -// CHECK-DAG: [[VAR_36_:%.+]] = affine.min [[MAP_1_]]([[VAR_34_]]) -// CHECK-DAG: [[VAR_37_:%.+]] = affine.apply [[MAP_2_]]([[VAR_34_]]) -// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: [[VAR_38_:%.+]] = affine.min [[MAP_3_]]([[VAR_34_]]) -// CHECK: scf.for [[I_1_:%.+]] = [[VAR_35_]] to [[VAR_38_]] step [[CST_32_]] { +// CHECK: [[VAR_35_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_36_:%.+]] = affine.apply [[MAP_0_]]([[VAR_35_]]) +// CHECK-DAG: [[VAR_37_:%.+]] = affine.min [[MAP_1_]]([[VAR_35_]]) +// CHECK-DAG: [[VAR_38_:%.+]] = affine.apply [[MAP_2_]]([[VAR_35_]]) +// CHECK: vector.store [[VAR_cst_7_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_cst_6_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_3_]]([[VAR_35_]]) +// CHECK: scf.for [[I_1_:%.+]] = [[VAR_36_]] to [[VAR_39_]] step [[CST_32_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4335xf32>, vector<32xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[I_1_]]{{.}} : memref<4335xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> -// CHECK-DAG: [[VAR_52_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> -// CHECK: vector.store [[VAR_51_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK: vector.store [[VAR_52_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_52_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<32xf32> +// CHECK-DAG: [[VAR_53_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_52_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK: vector.store [[VAR_53_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK: } -// CHECK: [[VAR_39_:%.+]] = affine.min [[MAP_4_]]([[VAR_34_]]) -// CHECK: [[VAR_40_:%.+]] = arith.remsi [[VAR_39_]], [[CST_32_]] : index -// CHECK: [[VAR_41_:%.+]] = arith.subi [[VAR_39_]], [[VAR_40_]] : index -// CHECK: [[VAR_42_:%.+]] = arith.addi [[VAR_35_]], [[VAR_41_]] : index -// CHECK: scf.for [[I_2_:%.+]] = [[VAR_42_]] to [[VAR_36_]] step [[CST_1_]] { +// CHECK: [[VAR_40_:%.+]] = affine.min [[MAP_4_]]([[VAR_35_]]) +// CHECK: [[VAR_41_:%.+]] = arith.remsi [[VAR_40_]], [[CST_32_]] : index +// CHECK: [[VAR_42_:%.+]] = arith.subi [[VAR_40_]], [[VAR_41_]] : index +// CHECK: [[VAR_43_:%.+]] = arith.addi [[VAR_36_]], [[VAR_42_]] : index +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_43_]] to [[VAR_37_]] step [[CST_1_]] { // CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4335xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[I_2_]]{{.}} : memref<4335xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = memref.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = memref.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 -// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 -// CHECK: memref.store [[VAR_51_1_]], [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> -// CHECK: memref.store [[VAR_52_1_]], [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_52_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_1_]], [[LOAD_VAR_reshape_MEM_2_]] : f32 +// CHECK-DAG: [[VAR_53_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_1_]], [[LOAD_VAR_reshape_MEM_3_]] : f32 +// CHECK: memref.store [[VAR_52_1_]], [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> +// CHECK: memref.store [[VAR_53_1_]], [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32> // CHECK: } -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_37_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = vector.load [[RES_4_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_2_:%.+]] = vector.load [[RES_6_]]{{.}}[[VAR_38_]]{{.}} : memref<256xf32>, vector<32xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_45_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 -// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 -// CHECK: memref.store [[VAR_45_]], [[RES_5_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> -// CHECK: memref.store [[VAR_46_]], [[RES_7_]]{{.}}[[VAR_34_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = vector.reduction , [[LOAD_RES_4_MEM_2_]] : vector<32xf32> into f32 +// CHECK-DAG: [[VAR_47_:%.+]] = vector.reduction , [[LOAD_RES_6_MEM_2_]] : vector<32xf32> into f32 +// CHECK: memref.store [[VAR_46_]], [[RES_5_]]{{.}}[[VAR_35_]]{{.}} : memref<8xf32> +// CHECK: memref.store [[VAR_47_]], [[RES_7_]]{{.}}[[VAR_35_]]{{.}} : memref<8xf32> // CHECK: } // CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() : memref // CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() : memref @@ -282,16 +283,16 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: vector.store [[VAR_cst_4_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 8){ -// CHECK: [[VAR_34_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> -// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_34_1_]]{{.}} : memref<8xf32> +// CHECK: [[VAR_35_1_:%.+]] = krnl.get_induction_var_value([[LOOP_1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_36_1_:%.+]] = krnl.load [[RES_5_]]{{.}}[[VAR_35_1_]]{{.}} : memref<8xf32> +// CHECK-DAG: [[VAR_37_1_:%.+]] = krnl.load [[RES_7_]]{{.}}[[VAR_35_1_]]{{.}} : memref<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_3_:%.+]] = krnl.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_3_:%.+]] = krnl.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_35_1_]] : f32 -// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_36_1_]] : f32 -// CHECK: krnl.store [[VAR_39_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> -// CHECK: krnl.store [[VAR_40_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK-DAG: [[VAR_40_1_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_3_]], [[VAR_36_1_]] : f32 +// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_3_]], [[VAR_37_1_]] : f32 +// CHECK: krnl.store [[VAR_40_1_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> +// CHECK: krnl.store [[VAR_41_1_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_4_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_4_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<256xf32>, vector<1xf32> @@ -327,10 +328,11 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK-DAG: [[VAR_27_:%.+]] = arith.select [[VAR_25_]], [[VAR_26_]], [[VAR_16_]] : f32 // CHECK-DAG: [[VAR_28_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_5_dot_000000_]] : f32 // CHECK: [[VAR_29_:%.+]] = arith.select [[VAR_28_]], [[VAR_27_]], [[VAR_20_]] : f32 -// CHECK: [[VAR_30_:%.+]] = arith.fptoui [[VAR_29_]] : f32 to i8 -// CHECK: [[VAR_31_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_]] : i8 to ui8 +// CHECK: [[VAR_30_:%.+]] = arith.fptoui [[VAR_29_]] : f32 to i32 +// CHECK: [[VAR_31_:%.+]] = arith.trunci [[VAR_30_]] : i32 to i8 +// CHECK: [[VAR_32_:%.+]] = builtin.unrealized_conversion_cast [[VAR_31_]] : i8 to ui8 // CHECK: krnl.store [[VAR_11_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_31_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_32_]], [[RES_2_]][] : memref // CHECK: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_4335_]], [[RES_10_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_23_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<255x17xf32>, memref<1xindex>) -> memref<4335xf32> @@ -341,64 +343,64 @@ func.func @test_dynamic_quantize_linear_simd_and_scalar(%arg0: tensor<255x17xf32 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_2_]] 16 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__0_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 4320){ -// CHECK: [[VAR_34_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xf32>, vector<16xf32> -// CHECK-DAG: [[VAR_36_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_35_1_]], [[VAR_36_2_]] : vector<16xf32> -// CHECK: [[VAR_38_1_:%.+]] = math.floor [[VAR_37_1_]] : vector<16xf32> -// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_1_]], [[VAR_38_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_40_2_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-DAG: [[VAR_41_1_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.select [[VAR_40_2_]], [[VAR_41_1_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_38_1_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK: [[VAR_35_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_36_1_:%.+]] = vector.load [[VAR_reshape_23_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xf32>, vector<16xf32> +// CHECK-DAG: [[VAR_37_2_:%.+]] = vector.splat [[VAR_11_]] : vector<16xf32> +// CHECK: [[VAR_38_1_:%.+]] = arith.divf [[VAR_36_1_]], [[VAR_37_2_]] : vector<16xf32> +// CHECK: [[VAR_39_1_:%.+]] = math.floor [[VAR_38_1_]] : vector<16xf32> +// CHECK: [[VAR_40_2_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_39_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.cmpf ogt, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-DAG: [[VAR_42_1_:%.+]] = arith.addf [[VAR_39_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_43_1_:%.+]] = arith.select [[VAR_41_2_]], [[VAR_42_1_]], [[VAR_39_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_4_MEM_2_:%.+]] = arith.mulf [[VAR_39_1_]], [[VAR_cst_1_]] : vector<16xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<16xf32> -// CHECK: [[VAR_45_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> -// CHECK: [[VAR_46_1_:%.+]] = arith.subf [[VAR_38_1_]], [[VAR_45_1_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_46_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_38_1_]], [[VAR_cst_3_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_38_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[VAR_cst_1_]] : vector<16xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_51_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_42_1_]] : vector<16xi1>, vector<16xf32> -// CHECK-DAG: [[VAR_52_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> -// CHECK: [[VAR_53_:%.+]] = arith.addf [[VAR_51_2_]], [[VAR_52_2_]] : vector<16xf32> -// CHECK: [[VAR_54_:%.+]] = arith.maxnumf [[VAR_53_]], [[VAR_cst_0_]] : vector<16xf32> -// CHECK: [[VAR_55_:%.+]] = arith.minnumf [[VAR_54_]], [[VAR_cst_]] : vector<16xf32> -// CHECK: [[VAR_56_:%.+]] = arith.fptoui [[VAR_55_]] : vector<16xf32> to vector<16xi32> -// CHECK: [[VAR_57_:%.+]] = arith.trunci [[VAR_56_]] : vector<16xi32> to vector<16xi8> -// CHECK: [[VAR_58_:%.+]] = builtin.unrealized_conversion_cast [[VAR_57_]] : vector<16xi8> to vector<16xui8> -// CHECK: vector.store [[VAR_58_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_2_]]{{.}} : memref<4335xui8>, vector<16xui8> +// CHECK: [[VAR_46_1_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_2_]] : vector<16xf32> +// CHECK: [[VAR_47_1_:%.+]] = arith.subf [[VAR_39_1_]], [[VAR_46_1_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = arith.cmpf oeq, [[VAR_47_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = arith.addf [[VAR_39_1_]], [[VAR_cst_3_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = arith.select [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_3_]], [[VAR_39_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = arith.cmpf oeq, [[VAR_40_2_]], [[VAR_cst_1_]] : vector<16xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_52_2_:%.+]] = arith.select [[LOAD_RES_6_MEM_1_]], [[LOAD_RES_4_MEM_1_]], [[VAR_43_1_]] : vector<16xi1>, vector<16xf32> +// CHECK-DAG: [[VAR_53_2_:%.+]] = vector.splat [[VAR_29_]] : vector<16xf32> +// CHECK: [[VAR_54_:%.+]] = arith.addf [[VAR_52_2_]], [[VAR_53_2_]] : vector<16xf32> +// CHECK: [[VAR_55_:%.+]] = arith.maxnumf [[VAR_54_]], [[VAR_cst_0_]] : vector<16xf32> +// CHECK: [[VAR_56_:%.+]] = arith.minnumf [[VAR_55_]], [[VAR_cst_]] : vector<16xf32> +// CHECK: [[VAR_57_:%.+]] = arith.fptoui [[VAR_56_]] : vector<16xf32> to vector<16xi32> +// CHECK: [[VAR_58_:%.+]] = arith.trunci [[VAR_57_]] : vector<16xi32> to vector<16xi8> +// CHECK: [[VAR_59_:%.+]] = builtin.unrealized_conversion_cast [[VAR_58_]] : vector<16xi8> to vector<16xui8> +// CHECK: vector.store [[VAR_59_]], [[VAR_reshape_25_]]{{.}}[[VAR_35_2_]]{{.}} : memref<4335xui8>, vector<16xui8> // CHECK: } // CHECK: [[LOOP_3_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_5_:%.+]] = 4320 to 4335){ -// CHECK: [[VAR_34_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index -// CHECK: [[VAR_35_1_1_:%.+]] = krnl.load [[VAR_reshape_23_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xf32> -// CHECK: [[VAR_36_3_:%.+]] = arith.divf [[VAR_35_1_1_]], [[VAR_11_]] : f32 -// CHECK: [[VAR_37_2_:%.+]] = math.floor [[VAR_36_3_]] : f32 -// CHECK: [[VAR_38_2_:%.+]] = arith.subf [[VAR_36_3_]], [[VAR_37_2_]] : f32 -// CHECK-DAG: [[VAR_39_3_:%.+]] = arith.cmpf ogt, [[VAR_38_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.addf [[VAR_37_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_41_2_:%.+]] = arith.select [[VAR_39_3_]], [[VAR_40_3_]], [[VAR_37_2_]] : f32 -// CHECK-DAG: [[VAR_42_2_:%.+]] = arith.mulf [[VAR_37_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = math.floor [[VAR_42_2_]] : f32 +// CHECK: [[VAR_35_3_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index +// CHECK: [[VAR_36_1_1_:%.+]] = krnl.load [[VAR_reshape_23_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xf32> +// CHECK: [[VAR_37_3_:%.+]] = arith.divf [[VAR_36_1_1_]], [[VAR_11_]] : f32 +// CHECK: [[VAR_38_2_:%.+]] = math.floor [[VAR_37_3_]] : f32 +// CHECK: [[VAR_39_2_:%.+]] = arith.subf [[VAR_37_3_]], [[VAR_38_2_]] : f32 +// CHECK-DAG: [[VAR_40_3_:%.+]] = arith.cmpf ogt, [[VAR_39_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_41_3_:%.+]] = arith.addf [[VAR_38_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_42_2_:%.+]] = arith.select [[VAR_40_3_]], [[VAR_41_3_]], [[VAR_38_2_]] : f32 +// CHECK-DAG: [[VAR_43_2_:%.+]] = arith.mulf [[VAR_38_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[LOAD_RES_4_MEM_2_1_:%.+]] = math.floor [[VAR_43_2_]] : f32 // CHECK: [[LOAD_RES_6_MEM_2_1_:%.+]] = arith.mulf [[LOAD_RES_4_MEM_2_1_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_45_2_:%.+]] = arith.subf [[VAR_37_2_]], [[LOAD_RES_6_MEM_2_1_]] : f32 -// CHECK-DAG: [[VAR_46_2_:%.+]] = arith.cmpf oeq, [[VAR_45_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = arith.addf [[VAR_37_2_]], [[CST_1_dot_000000_]] : f32 -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.select [[VAR_46_2_]], [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_37_2_]] : f32 -// CHECK-DAG: [[LOAD_RES_4_MEM_1_1_:%.+]] = arith.cmpf oeq, [[VAR_38_2_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.select [[LOAD_RES_4_MEM_1_1_]], [[LOAD_VAR_reshape_MEM_3_1_]], [[VAR_41_2_]] : f32 -// CHECK: [[VAR_51_3_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_1_]], [[VAR_29_]] : f32 -// CHECK: [[VAR_52_3_:%.+]] = arith.maxnumf [[VAR_51_3_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_53_1_:%.+]] = arith.minnumf [[VAR_52_3_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_54_1_:%.+]] = arith.fptoui [[VAR_53_1_]] : f32 to i32 -// CHECK: [[VAR_55_1_:%.+]] = arith.trunci [[VAR_54_1_]] : i32 to i8 -// CHECK: [[VAR_56_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_1_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_56_1_]], [[VAR_reshape_25_]]{{.}}[[VAR_34_3_]]{{.}} : memref<4335xui8> +// CHECK: [[VAR_46_2_:%.+]] = arith.subf [[VAR_38_2_]], [[LOAD_RES_6_MEM_2_1_]] : f32 +// CHECK-DAG: [[VAR_47_2_:%.+]] = arith.cmpf oeq, [[VAR_46_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_1_:%.+]] = arith.addf [[VAR_38_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_1_:%.+]] = arith.select [[VAR_47_2_]], [[LOAD_VAR_reshape_MEM_2_1_]], [[VAR_38_2_]] : f32 +// CHECK-DAG: [[LOAD_RES_4_MEM_1_1_:%.+]] = arith.cmpf oeq, [[VAR_39_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[LOAD_RES_6_MEM_1_1_:%.+]] = arith.select [[LOAD_RES_4_MEM_1_1_]], [[LOAD_VAR_reshape_MEM_3_1_]], [[VAR_42_2_]] : f32 +// CHECK: [[VAR_52_3_:%.+]] = arith.addf [[LOAD_RES_6_MEM_1_1_]], [[VAR_29_]] : f32 +// CHECK: [[VAR_53_3_:%.+]] = arith.maxnumf [[VAR_52_3_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_54_1_:%.+]] = arith.minnumf [[VAR_53_3_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_55_1_:%.+]] = arith.fptoui [[VAR_54_1_]] : f32 to i32 +// CHECK: [[VAR_56_1_:%.+]] = arith.trunci [[VAR_55_1_]] : i32 to i8 +// CHECK: [[VAR_57_1_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_1_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_57_1_]], [[VAR_reshape_25_]]{{.}}[[VAR_35_3_]]{{.}} : memref<4335xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_13, [[RES_]]_14 : memref<255x17xui8>, memref, memref // CHECK: } @@ -443,16 +445,16 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 8){ -// CHECK: [[VAR_32_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_32_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_33_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_33_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_37_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.minnumf [[LOAD_RES_4_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.maxnumf [[LOAD_RES_6_MEM_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_4_MEM_1_:%.+]] = vector.load [[RES_4_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_RES_6_MEM_1_:%.+]] = vector.load [[RES_6_]]{{.}}[[CST_0_]]{{.}} : memref<8xf32>, vector<8xf32> @@ -488,10 +490,11 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_]] : f32 // CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 // CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 -// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i8 -// CHECK: [[VAR_30_:%.+]] = builtin.unrealized_conversion_cast [[VAR_29_]] : i8 to ui8 +// CHECK: [[VAR_29_:%.+]] = arith.fptoui [[VAR_28_]] : f32 to i32 +// CHECK: [[VAR_30_:%.+]] = arith.trunci [[VAR_29_]] : i32 to i8 +// CHECK: [[VAR_31_:%.+]] = builtin.unrealized_conversion_cast [[VAR_30_]] : i8 to ui8 // CHECK: krnl.store [[VAR_10_]], [[RES_1_]][] : memref -// CHECK: krnl.store [[VAR_30_]], [[RES_2_]][] : memref +// CHECK: krnl.store [[VAR_31_]], [[RES_2_]][] : memref // CHECK: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> // CHECK: affine.store [[CST_8_]], [[RES_8_]][0] : memref<1xindex> // CHECK-DAG: [[VAR_reshape_19_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_8_]]) : (memref<1x8xf32>, memref<1xindex>) -> memref<8xf32> @@ -502,35 +505,35 @@ func.func @test_dynamic_quantize_linear_reduced_simd_only(%arg0: tensor<1x8xf32> // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 8 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.parallel([[BLOCK_TILE__1_]]) : !krnl.loop // CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = 0 to 8){ -// CHECK: [[VAR_32_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xf32>, vector<8xf32> +// CHECK: [[VAR_33_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_19_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xf32>, vector<8xf32> // CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.splat [[VAR_10_]] : vector<8xf32> // CHECK: [[LOAD_RES_4_MEM_2_:%.+]] = arith.divf [[LOAD_VAR_reshape_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<8xf32> // CHECK: [[LOAD_RES_6_MEM_2_:%.+]] = math.floor [[LOAD_RES_4_MEM_2_]] : vector<8xf32> -// CHECK: [[VAR_37_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> -// CHECK-DAG: [[VAR_38_1_:%.+]] = arith.cmpf ogt, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_40_:%.+]] = arith.select [[VAR_38_1_]], [[VAR_39_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_41_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK: [[VAR_42_:%.+]] = math.floor [[VAR_41_]] : vector<8xf32> -// CHECK: [[VAR_43_:%.+]] = arith.mulf [[VAR_42_]], [[VAR_cst_2_]] : vector<8xf32> -// CHECK: [[VAR_44_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_43_]] : vector<8xf32> -// CHECK-DAG: [[VAR_45_:%.+]] = arith.cmpf oeq, [[VAR_44_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_47_:%.+]] = arith.select [[VAR_45_]], [[VAR_46_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_48_:%.+]] = arith.cmpf oeq, [[VAR_37_1_]], [[VAR_cst_1_]] : vector<8xf32> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_49_:%.+]] = arith.select [[VAR_48_]], [[VAR_47_]], [[VAR_40_]] : vector<8xi1>, vector<8xf32> -// CHECK-DAG: [[VAR_50_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> -// CHECK: [[VAR_51_:%.+]] = arith.addf [[VAR_49_]], [[VAR_50_]] : vector<8xf32> -// CHECK: [[VAR_52_:%.+]] = arith.maxnumf [[VAR_51_]], [[VAR_cst_0_]] : vector<8xf32> -// CHECK: [[VAR_53_:%.+]] = arith.minnumf [[VAR_52_]], [[VAR_cst_]] : vector<8xf32> -// CHECK: [[VAR_54_:%.+]] = arith.fptoui [[VAR_53_]] : vector<8xf32> to vector<8xi32> -// CHECK: [[VAR_55_:%.+]] = arith.trunci [[VAR_54_]] : vector<8xi32> to vector<8xi8> -// CHECK: [[VAR_56_:%.+]] = builtin.unrealized_conversion_cast [[VAR_55_]] : vector<8xi8> to vector<8xui8> -// CHECK: vector.store [[VAR_56_]], [[VAR_reshape_21_]]{{.}}[[VAR_32_1_]]{{.}} : memref<8xui8>, vector<8xui8> +// CHECK: [[VAR_38_1_:%.+]] = arith.subf [[LOAD_RES_4_MEM_2_]], [[LOAD_RES_6_MEM_2_]] : vector<8xf32> +// CHECK-DAG: [[VAR_39_1_:%.+]] = arith.cmpf ogt, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-DAG: [[VAR_40_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_41_:%.+]] = arith.select [[VAR_39_1_]], [[VAR_40_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_42_:%.+]] = arith.mulf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK: [[VAR_43_:%.+]] = math.floor [[VAR_42_]] : vector<8xf32> +// CHECK: [[VAR_44_:%.+]] = arith.mulf [[VAR_43_]], [[VAR_cst_2_]] : vector<8xf32> +// CHECK: [[VAR_45_:%.+]] = arith.subf [[LOAD_RES_6_MEM_2_]], [[VAR_44_]] : vector<8xf32> +// CHECK-DAG: [[VAR_46_:%.+]] = arith.cmpf oeq, [[VAR_45_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf [[LOAD_RES_6_MEM_2_]], [[VAR_cst_3_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_48_:%.+]] = arith.select [[VAR_46_]], [[VAR_47_]], [[LOAD_RES_6_MEM_2_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.cmpf oeq, [[VAR_38_1_]], [[VAR_cst_1_]] : vector<8xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_50_:%.+]] = arith.select [[VAR_49_]], [[VAR_48_]], [[VAR_41_]] : vector<8xi1>, vector<8xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = vector.splat [[VAR_28_]] : vector<8xf32> +// CHECK: [[VAR_52_:%.+]] = arith.addf [[VAR_50_]], [[VAR_51_]] : vector<8xf32> +// CHECK: [[VAR_53_:%.+]] = arith.maxnumf [[VAR_52_]], [[VAR_cst_0_]] : vector<8xf32> +// CHECK: [[VAR_54_:%.+]] = arith.minnumf [[VAR_53_]], [[VAR_cst_]] : vector<8xf32> +// CHECK: [[VAR_55_:%.+]] = arith.fptoui [[VAR_54_]] : vector<8xf32> to vector<8xi32> +// CHECK: [[VAR_56_:%.+]] = arith.trunci [[VAR_55_]] : vector<8xi32> to vector<8xi8> +// CHECK: [[VAR_57_:%.+]] = builtin.unrealized_conversion_cast [[VAR_56_]] : vector<8xi8> to vector<8xui8> +// CHECK: vector.store [[VAR_57_]], [[VAR_reshape_21_]]{{.}}[[VAR_33_1_]]{{.}} : memref<8xui8>, vector<8xui8> // CHECK: } // CHECK: return [[RES_]], [[RES_]]_11, [[RES_]]_12 : memref<1x8xui8>, memref, memref // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir b/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir index 5309b274fd..f319459c68 100644 --- a/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir +++ b/test/mlir/conversion/onnx_to_krnl/onnx_lowering_fuse.mlir @@ -124,8 +124,9 @@ func.func @test_fuse_element8(%arg0: tensor, %arg1: tensor<1xf32>) -> ten // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_0_]]{{.}} : memref<1xf32> // CHECK: [[VAR_4_:%.+]] = math.powf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: [[VAR_5_:%.+]] = arith.fptosi [[VAR_4_]] : f32 to i8 -// CHECK: krnl.store [[VAR_5_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref +// CHECK: [[VAR_5_:%.+]] = arith.fptosi [[VAR_4_]] : f32 to i32 +// CHECK: [[VAR_6_:%.+]] = arith.trunci [[VAR_5_]] : i32 to i8 +// CHECK: krnl.store [[VAR_6_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -322,13 +323,14 @@ func.func @fuse_element_20(%533: tensor, %537 : tensor, // ----- + func.func @test_fuse_element21(%arg0: tensor, %arg1: tensor<1xf32>, %arg2 : tensor<1xi8>) -> tensor { - %0 = "onnx.Pow"(%arg0, %arg1) : (tensor, tensor<1xf32>) -> tensor + %0 = "onnx.Pow"(%arg0, %arg1) : (tensor, tensor<1xf32>) -> tensor %1 = "onnx.Cast"(%0) {to = i8} : (tensor) -> tensor %2 = "onnx.Add"(%1, %arg2) : (tensor, tensor<1xi8>) -> tensor return %2 : tensor -} +// mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @test_fuse_element21 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref, [[PARAM_1_:%.+]]: memref<1xf32>, [[PARAM_2_:%.+]]: memref<1xi8>) -> memref { @@ -341,12 +343,13 @@ func.func @test_fuse_element21(%arg0: tensor, %arg1: tensor<1xf32>, %arg2 // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[CST_0_]]{{.}} : memref<1xf32> // CHECK: [[VAR_4_:%.+]] = math.powf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK-DAG: [[VAR_5_:%.+]] = arith.fptosi [[VAR_4_]] : f32 to i8 +// CHECK: [[VAR_5_:%.+]] = arith.fptosi [[VAR_4_]] : f32 to i32 +// CHECK-DAG: [[VAR_6_:%.+]] = arith.trunci [[VAR_5_]] : i32 to i8 // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]]{{.}}[[CST_0_]]{{.}} : memref<1xi8> -// CHECK: [[VAR_7_:%.+]] = arith.addi [[VAR_5_]], [[LOAD_PARAM_2_MEM_]] : i8 -// CHECK: krnl.store [[VAR_7_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_6_]], [[LOAD_PARAM_2_MEM_]] : i8 +// CHECK: krnl.store [[VAR_8_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } - +} diff --git a/utils/fixLitTest.py b/utils/fixLitTest.py index 717f2c3567..b9e5ab3b46 100755 --- a/utils/fixLitTest.py +++ b/utils/fixLitTest.py @@ -425,6 +425,7 @@ def main(argv): dprint("\n>> Tested with " + str(test_error_num) + " errors:") for f in test_error_functions: dprint(">> " + f) + dprint(">> Completed processing of " + lit_test_filename + "\n") if __name__ == "__main__": From 28863c82a49bcbbde5e07798cfad677f17e16abe Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Thu, 26 Sep 2024 14:46:15 -0400 Subject: [PATCH 9/9] update lit tests Signed-off-by: Alexandre Eichenberger --- .../DequantizeLinear_with_canonicalize.mlir | 29 ++--- .../QuantizationWithoutZeroPoint.mlir | 7 +- .../QuantizeLinear_with_canonicalize.mlir | 101 +++++++++--------- 3 files changed, 72 insertions(+), 65 deletions(-) diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir index 93d38fc77a..4873f5a1a4 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir @@ -5,10 +5,12 @@ // ----- + func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor, %arg2: tensor) -> tensor<4xf32> { %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> +// mlir2FileCheck.py // CHECK-LABEL: func.func @test_dequantizelinear_i8 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xi8>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<4xf32> { // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32> @@ -18,12 +20,13 @@ func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor, %ar // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xi8> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_5_:%.+]] = arith.sitofp [[LOAD_PARAM_2_MEM_]] : i8 to f32 -// CHECK-DAG: [[VAR_6_:%.+]] = arith.sitofp [[LOAD_PARAM_0_MEM_]] : i8 to f32 -// CHECK: [[VAR_7_:%.+]] = arith.subf [[VAR_6_]], [[VAR_5_]] : f32 -// CHECK: [[VAR_8_:%.+]] = arith.mulf [[VAR_7_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: krnl.store [[VAR_8_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> +// CHECK: [[VAR_5_:%.+]] = arith.extsi [[LOAD_PARAM_0_MEM_]] : i8 to i32 +// CHECK-DAG: [[VAR_6_:%.+]] = arith.sitofp [[VAR_5_]] : i32 to f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.extsi [[LOAD_PARAM_2_MEM_]] : i8 to i32 +// CHECK: [[VAR_8_:%.+]] = arith.sitofp [[VAR_7_]] : i32 to f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref<4xf32> // CHECK: } @@ -47,12 +50,14 @@ func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, % // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref // CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 -// CHECK-DAG: [[VAR_6_:%.+]] = arith.uitofp [[VAR_5_]] : i8 to f32 -// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 -// CHECK: [[VAR_8_:%.+]] = arith.uitofp [[VAR_7_]] : i8 to f32 -// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32 -// CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> +// CHECK: [[VAR_6_:%.+]] = arith.extui [[VAR_5_]] : i8 to i32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.uitofp [[VAR_6_]] : i32 to f32 +// CHECK-DAG: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 +// CHECK: [[VAR_9_:%.+]] = arith.extui [[VAR_8_]] : i8 to i32 +// CHECK: [[VAR_10_:%.+]] = arith.uitofp [[VAR_9_]] : i32 to f32 +// CHECK: [[VAR_11_:%.+]] = arith.subf [[VAR_7_]], [[VAR_10_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.mulf [[VAR_11_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_12_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref<4xf32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir index 91b68f1adb..15fedeab1b 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir @@ -22,9 +22,10 @@ func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, % // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xui8> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref // CHECK: [[VAR_4_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 -// CHECK: [[VAR_5_:%.+]] = arith.uitofp [[VAR_4_]] : i8 to f32 -// CHECK: [[VAR_6_:%.+]] = arith.mulf [[VAR_5_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: krnl.store [[VAR_6_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> +// CHECK: [[VAR_5_:%.+]] = arith.extui [[VAR_4_]] : i8 to i32 +// CHECK: [[VAR_6_:%.+]] = arith.uitofp [[VAR_5_]] : i32 to f32 +// CHECK: [[VAR_7_:%.+]] = arith.mulf [[VAR_6_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_7_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref<4xf32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir index 8717a0f795..42d7cad73e 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizeLinear_with_canonicalize.mlir @@ -22,35 +22,36 @@ func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %a // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref // CHECK: [[VAR_2_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 -// CHECK-DAG: [[VAR_3_:%.+]] = arith.uitofp [[VAR_2_]] : i8 to f32 +// CHECK: [[VAR_3_:%.+]] = arith.extui [[VAR_2_]] : i8 to i32 +// CHECK-DAG: [[VAR_4_:%.+]] = arith.uitofp [[VAR_3_]] : i32 to f32 // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ -// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_5_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_7_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: [[VAR_8_:%.+]] = math.floor [[VAR_7_]] : f32 -// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 -// CHECK-DAG: [[VAR_10_:%.+]] = arith.cmpf ogt, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_11_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_6_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_6_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_8_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_9_:%.+]] = math.floor [[VAR_8_]] : f32 +// CHECK: [[VAR_10_:%.+]] = arith.subf [[VAR_8_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.cmpf ogt, [[VAR_10_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_10_]], [[VAR_11_]], [[VAR_8_]] : f32 -// CHECK-DAG: [[VAR_13_:%.+]] = arith.mulf [[VAR_8_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_14_:%.+]] = math.floor [[VAR_13_]] : f32 -// CHECK: [[VAR_15_:%.+]] = arith.mulf [[VAR_14_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_8_]], [[VAR_15_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_13_:%.+]] = arith.select [[VAR_11_]], [[VAR_12_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.mulf [[VAR_9_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_15_:%.+]] = math.floor [[VAR_14_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.mulf [[VAR_15_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_17_:%.+]] = arith.subf [[VAR_9_]], [[VAR_16_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.cmpf oeq, [[VAR_17_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_9_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_8_]] : f32 -// CHECK-DAG: [[VAR_20_:%.+]] = arith.cmpf oeq, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_21_:%.+]] = arith.select [[VAR_20_]], [[VAR_19_]], [[VAR_12_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_3_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[VAR_22_]], [[CST_0_dot_000000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.minnumf [[VAR_23_]], [[CST_2_dot_550000_]] : f32 -// CHECK: [[VAR_25_:%.+]] = arith.fptoui [[VAR_24_]] : f32 to i32 -// CHECK: [[VAR_26_:%.+]] = arith.trunci [[VAR_25_]] : i32 to i8 -// CHECK: [[VAR_27_:%.+]] = builtin.unrealized_conversion_cast [[VAR_26_]] : i8 to ui8 -// CHECK: krnl.store [[VAR_27_]], [[RES_]]{{.}}[[VAR_5_]]{{.}} : memref<6xui8> +// CHECK-DAG: [[VAR_20_:%.+]] = arith.select [[VAR_18_]], [[VAR_19_]], [[VAR_9_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.cmpf oeq, [[VAR_10_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.select [[VAR_21_]], [[VAR_20_]], [[VAR_13_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.addf [[VAR_22_]], [[VAR_4_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.maxnumf [[VAR_23_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.minnumf [[VAR_24_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_26_:%.+]] = arith.fptoui [[VAR_25_]] : f32 to i32 +// CHECK: [[VAR_27_:%.+]] = arith.trunci [[VAR_26_]] : i32 to i8 +// CHECK: [[VAR_28_:%.+]] = builtin.unrealized_conversion_cast [[VAR_27_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_28_]], [[RES_]]{{.}}[[VAR_6_]]{{.}} : memref<6xui8> // CHECK: } // CHECK: return [[RES_]] : memref<6xui8> // CHECK: } @@ -74,35 +75,35 @@ func.func @test_quantize_linear_i8(%arg0: tensor<6xf32>, %arg1: tensor, %ar // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xi8> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = arith.sitofp [[LOAD_PARAM_2_MEM_]] : i8 to f32 +// CHECK: [[VAR_2_:%.+]] = arith.extsi [[LOAD_PARAM_2_MEM_]] : i8 to i32 +// CHECK-DAG: [[VAR_3_:%.+]] = arith.sitofp [[VAR_2_]] : i32 to f32 // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ -// CHECK: [[VAR_4_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index -// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_4_]]{{.}} : memref<6xf32> -// CHECK: [[VAR_6_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 -// CHECK: [[VAR_7_:%.+]] = math.floor [[VAR_6_]] : f32 -// CHECK: [[VAR_8_:%.+]] = arith.subf [[VAR_6_]], [[VAR_7_]] : f32 -// CHECK-DAG: [[VAR_9_:%.+]] = arith.cmpf ogt, [[VAR_8_]], [[CST_5_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_10_:%.+]] = arith.addf [[VAR_7_]], [[CST_1_dot_000000_]] : f32 +// CHECK: [[VAR_5_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_5_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_7_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_8_:%.+]] = math.floor [[VAR_7_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_7_]], [[VAR_8_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.cmpf ogt, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_11_:%.+]] = arith.select [[VAR_9_]], [[VAR_10_]], [[VAR_7_]] : f32 -// CHECK-DAG: [[VAR_12_:%.+]] = arith.mulf [[VAR_7_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_13_:%.+]] = math.floor [[VAR_12_]] : f32 -// CHECK: [[VAR_14_:%.+]] = arith.mulf [[VAR_13_]], [[CST_2_dot_000000_]] : f32 -// CHECK: [[VAR_15_:%.+]] = arith.subf [[VAR_7_]], [[VAR_14_]] : f32 -// CHECK-DAG: [[VAR_16_:%.+]] = arith.cmpf oeq, [[VAR_15_]], [[CST_1_dot_000000_]] : f32 -// CHECK-DAG: [[VAR_17_:%.+]] = arith.addf [[VAR_7_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_12_:%.+]] = arith.select [[VAR_10_]], [[VAR_11_]], [[VAR_8_]] : f32 +// CHECK-DAG: [[VAR_13_:%.+]] = arith.mulf [[VAR_8_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_14_:%.+]] = math.floor [[VAR_13_]] : f32 +// CHECK: [[VAR_15_:%.+]] = arith.mulf [[VAR_14_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.subf [[VAR_8_]], [[VAR_15_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_8_]], [[CST_1_dot_000000_]] : f32 // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_18_:%.+]] = arith.select [[VAR_16_]], [[VAR_17_]], [[VAR_7_]] : f32 -// CHECK-DAG: [[VAR_19_:%.+]] = arith.cmpf oeq, [[VAR_8_]], [[CST_5_dot_000000_]] : f32 -// CHECK: [[VAR_20_:%.+]] = arith.select [[VAR_19_]], [[VAR_18_]], [[VAR_11_]] : f32 -// CHECK: [[VAR_21_:%.+]] = arith.addf [[VAR_20_]], [[VAR_2_]] : f32 -// CHECK: [[VAR_22_:%.+]] = arith.maxnumf [[VAR_21_]], [[CST_minus_1_dot_280000_]] : f32 -// CHECK: [[VAR_23_:%.+]] = arith.minnumf [[VAR_22_]], [[CST_1_dot_270000_]] : f32 -// CHECK: [[VAR_24_:%.+]] = arith.fptosi [[VAR_23_]] : f32 to i32 -// CHECK: [[VAR_25_:%.+]] = arith.trunci [[VAR_24_]] : i32 to i8 -// CHECK: krnl.store [[VAR_25_]], [[RES_]]{{.}}[[VAR_4_]]{{.}} : memref<6xi8> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_8_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.cmpf oeq, [[VAR_9_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = arith.select [[VAR_20_]], [[VAR_19_]], [[VAR_12_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_3_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[VAR_22_]], [[CST_minus_1_dot_280000_]] : f32 +// CHECK: [[VAR_24_:%.+]] = arith.minnumf [[VAR_23_]], [[CST_1_dot_270000_]] : f32 +// CHECK: [[VAR_25_:%.+]] = arith.fptosi [[VAR_24_]] : f32 to i32 +// CHECK: [[VAR_26_:%.+]] = arith.trunci [[VAR_25_]] : i32 to i8 +// CHECK: krnl.store [[VAR_26_]], [[RES_]]{{.}}[[VAR_5_]]{{.}} : memref<6xi8> // CHECK: } // CHECK: return [[RES_]] : memref<6xi8> // CHECK: }