diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp index 9a5a06b03a..9c0b66b432 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp @@ -261,7 +261,7 @@ class UnstickExpansionPattern : public OpRewritePattern { // Store f32 values back to the (normal layout) output. DimsExpr outputAF = SymListIE(inputAF); outputAF[E1] = outputAF[E1] + l; - create.vec.storeIE(vecF32H, alloc, outputAF, {}); + create.vec.storeIE(vecF32H, alloc, outputAF); create.vec.storeIE( vecF32L, alloc, outputAF, {litArchVLHalf.getValue()}); }); @@ -277,8 +277,8 @@ class UnstickExpansionPattern : public OpRewritePattern { Value vecF32L = convertOp.getResult(1); // Save into archVL value buffer. Value bufferF32 = create.mem.alignedAlloca(bufferType); - create.vec.storeIE(vecF32H, bufferF32, {litZero}, {}); - create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf}, {}); + create.vec.storeIE(vecF32H, bufferF32, {litZero}); + create.vec.storeIE(vecF32L, bufferF32, {litArchVLHalf}); // Save the remaining values as scalars. create.scf.forLoop(litZero.getValue(), remainingScalarValues.getValue(), 1, diff --git a/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp b/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp index 3e6c53fa14..ff806f1d3a 100644 --- a/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp +++ b/src/Conversion/KrnlToAffine/KrnlCopyFromBuffer.cpp @@ -124,7 +124,7 @@ class KrnlCopyFromBufferLowering : public ConversionPattern { // Nothing to write. } else { // Loop to copy the data. - createAffine.forLoopIE(zeroIE, writeUBs[i], 1, + createAffine.forLoopIE(zeroIE, writeUBs[i], 1, false /*parallel*/, [&](AffineBuilderKrnlMem &createAffine, ValueRange loopInd) { loopIndices.emplace_back(loopInd[0]); genCopyLoops(createAffine, enclosingScope, buffMemref, destMemref, diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index 81c4b9768b..1ef0090049 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -1527,8 +1527,7 @@ static LogicalResult getPartiallyFlattenedSimdCode( create.krnl.simdIterateIE(zero, SymIE(simdUb), VL, simdOnly, useParallelInSimdLoop, inputs, inputAFs, {output}, {outputAF}, - [&](KrnlBuilder &kb, ArrayRef inputVals, - SmallVectorImpl &resVals, int64_t VL) { + {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { MultiDialectBuilder create(kb); Type currElementType = outputElementType; if (VL > 1) @@ -1557,9 +1556,9 @@ static LogicalResult getPartiallyFlattenedSimdCode( res = emitPostProcessingFor(rewriter, create.getLoc(), op, currElementType, accumulated); } - resVals.emplace_back(res); - }); // SIMD kernel. - }); // Outer loops. + return res; + }}); // SIMD kernel. + }); // Outer loops. 
rewriter.replaceOp(op, alloc); return success(); diff --git a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp index ccdd5705fe..e969ca8e85 100644 --- a/src/Conversion/ONNXToKrnl/Math/Reduction.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Reduction.cpp @@ -19,7 +19,8 @@ #include "src/Support/SmallVectorHelper.hpp" #define DEBUG_TYPE "lowering-to-krnl" -#define DEBUG_FORCE_SHUFFLE_REDUCTION 0 +#define DEBUG_FORCE_SHUFFLE_REDUCTION 0 /* should be 0 in repo */ +#define REDUCTION_MULTIPLE_OF_VL_ONLY 0 /* 0: improved;1: old, for debug */ using namespace mlir; @@ -279,9 +280,9 @@ Value emitScalarOpFor(ConversionPatternRewriter &rewriter, //===----------------------------------------------------------------------===// -using MDBuilder = - MultiDialectBuilder; +using MDBuilder = MultiDialectBuilder; //===----------------------------------------------------------------------===// // Helper function to perform reduction when an entire tensor is reduced to a @@ -330,47 +331,54 @@ void emitOneStepOfFullSIMDReduction(ConversionPatternRewriter &rewriter, rewriter, create.getLoc(), elementType)); } - BUILDER builder(create.vec); - builder.simdReduceIE( - lb, ub, VL, simdOnly, inputs, inputAFs, tmps, tmpAFs, outputs, outputAFs, - initVals, - /* reduction function */ - [&](const BUILDER &b, ArrayRef inputVals, ArrayRef tmpVals, - llvm::SmallVectorImpl &resultVals, int64_t VL) { + // Create the reduction functions. + llvm::SmallVector, 2> redBodyFnList; + llvm::SmallVector, 2> + postRedBodyFnList; + // Push functions for the first reduction. + redBodyFnList.emplace_back( + [&](const BUILDER &b, Value inputVal, Value tmpVal, int64_t VL) { Type currType = (VL > 1) ? vecType : elementType; - // First reduction, enqueue result. - Value accumulatedVec1 = emitScalarOpFor(rewriter, - create.getLoc(), op, currType, {tmpVals[0], inputVals[0]}); - resultVals.emplace_back(accumulatedVec1); - if (hasTwoRed) { - // Has a second reduction, also enqueue result. - Value accumulatedVec2 = emitScalarOpFor(rewriter, - create.getLoc(), op, currType, {tmpVals[1], inputVals[1]}); - resultVals.emplace_back(accumulatedVec2); - } - }, - /* post reduction function*/ - [&](const BUILDER &b, ArrayRef tmpVals, - llvm::SmallVectorImpl &scalarOutputs, int64_t VL) { + // Perform reduction of tmp and input. + return emitScalarOpFor( + rewriter, create.getLoc(), op, currType, {tmpVal, inputVal}); + }); + postRedBodyFnList.emplace_back( + [&](const BUILDER &b, Value tmpVal, int64_t VL) { // Perform horizontal reductions. - Value res1 = create.vec.reduction( - getCombiningKind(), tmpVals[0]); - scalarOutputs.emplace_back(res1); - if (hasTwoRed) { - Value res2 = create.vec.reduction( - getCombiningKind(), tmpVals[1]); - scalarOutputs.emplace_back(res2); - } - // Handle means if any. - if (tNum > 1) { /* parallel: do it for the final iteration only */ + Value scalarVal = + create.vec.reduction(getCombiningKind(), tmpVal); + if (tNum == 1) { /* parallel: do it for the final iteration only */ if (divideByMean()) - scalarOutputs[0] = - create.math.div(scalarOutputs[0], divisorForMean); - if (hasTwoRed && divideByMean()) - scalarOutputs[1] = - create.math.div(scalarOutputs[1], divisorForMean); + scalarVal = create.math.div(scalarVal, divisorForMean); } + return scalarVal; }); + if (hasTwoRed) { + // Push functions for the second reduction. + redBodyFnList.emplace_back( + [&](const BUILDER &b, Value inputVal, Value tmpVal, int64_t VL) { + Type currType = (VL > 1) ? 
vecType : elementType; + // Perform reduction of tmp and input. + return emitScalarOpFor( + rewriter, create.getLoc(), op, currType, {tmpVal, inputVal}); + }); + postRedBodyFnList.emplace_back( + [&](const BUILDER &b, Value tmpVal, int64_t VL) { + // Perform horizontal reductions. + Value scalarVal = create.vec.reduction( + getCombiningKind(), tmpVal); + if (tNum == 1) { /* parallel: do it for the final iteration only */ + if (divideByMean()) + scalarVal = create.math.div(scalarVal, divisorForMean); + } + return scalarVal; + }); + } + // Call simd reduce. + BUILDER builder(create.vec); + builder.simdReduceIE(lb, ub, VL, simdOnly, inputs, inputAFs, tmps, tmpAFs, + outputs, outputAFs, initVals, redBodyFnList, postRedBodyFnList); } template @@ -702,6 +710,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { bool parallelSimd = false; int64_t innermostLoopCollapse = 0; int64_t totVL = 1; + bool simdOnly = false; int64_t simdLoopStaticTripCount = 0; // With dynamic axes, use this @@ -755,7 +764,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { // of the natural SIMD width. Aka, we don't deal with SIMD of // partial vectors. GenOpMix mix = getGenOpMix(elementOutType, op); - bool simdOnly, canOverCompute = false; + bool canOverCompute = false; totVL = computeSuitableUnrollFactor(memRefInType, innermostLoopCollapse, mix, canOverCompute, simdLoopStaticTripCount, simdOnly); @@ -765,11 +774,15 @@ struct ONNXReductionOpLowering : public OpConversionPattern { // here. Some benchmarks have small trip counts (e.g. GPT2: 8). totVL = capVLForMaxUnroll(memRefInType, totVL, 1); } - // Current code gen scheme only support SIMD only scheme. +#if REDUCTION_MULTIPLE_OF_VL_ONLY + // Currently fails with krnl to affine without this. Should + // consider an affine simd iterate/reduce. 
onnx-mlir + // -shapeInformation=0:4x8 reducemean2.mlir -O3 -march=arm64 if (!simdOnly) { totVL = capVLForSimdOnly(memRefInType, totVL, simdLoopStaticTripCount); } +#endif LLVM_DEBUG(llvm::dbgs() << " SIMD: " << innermostLoopCollapse << " loops, totVL " << totVL << "\n"); if (totVL <= 1) { @@ -906,14 +919,14 @@ struct ONNXReductionOpLowering : public OpConversionPattern { if (horizontalSimd) { if (hasHorizontalSimdSupport) { genHorizontalSimdReduction(rewriter, create, op, elementOutType, input, - alloc, inRank, outRank, totVL, innermostLoopCollapse, isKeepdims, - divisorForMean, enableParallel); + alloc, inRank, outRank, totVL, simdOnly, innermostLoopCollapse, + isKeepdims, divisorForMean, enableParallel); onnxToKrnlSimdReport(op, /*successful*/ true, totVL, simdLoopStaticTripCount, "horizontal"); } else { genShuffleHorizontalSimdReduction(rewriter, create, op, elementOutType, - input, alloc, inRank, outRank, totVL, innermostLoopCollapse, - isKeepdims, divisorForMean, enableParallel); + input, alloc, inRank, outRank, totVL, simdOnly, + innermostLoopCollapse, isKeepdims, divisorForMean, enableParallel); onnxToKrnlSimdReport(op, /*successful*/ true, totVL, simdLoopStaticTripCount, "shuffle-horizontal"); } @@ -1054,46 +1067,38 @@ struct ONNXReductionOpLowering : public OpConversionPattern { void genOneHorizontalSimdReduction(ConversionPatternRewriter &rewriter, MDBuilder &create, Operation *op, Type elementType, VectorType vecType, Value tmpAlloca, Value flatInput, Value flatAlloc, Value initVec, - Value divisorForMean, ValueRange outLoopInd, Value simdUB, - int64_t VL) const { + Value divisorForMean, ValueRange outLoopInd, Value simdUB, int64_t VL, + bool simdOnly) const { IndexExpr lb = LitIE(0); IndexExpr ub = SymIE(simdUB); - bool fullySIMD = true; SmallVector outputAF = SymListIE(outLoopInd); SmallVector inputAF = outputAF; inputAF.emplace_back(lb); SmallVector tmpAF(2, lb); // tmpAlloc is 2D Value identity = getIdentityValue( rewriter, create.getLoc(), elementType); - create.krnl.simdReduceIE( - lb, ub, VL, fullySIMD, + create.krnl.simdReduceIE(lb, ub, VL, simdOnly, /* inputs*/ {flatInput}, {inputAF}, /* temp */ {tmpAlloca}, {tmpAF}, /* output */ {flatAlloc}, {outputAF}, /* init */ {identity}, /* reduction simd/scalar */ - [&](const KrnlBuilder &kb, ArrayRef inputVals, - ArrayRef tmpVals, llvm::SmallVectorImpl &resultVals, - int64_t VL) { - Value input = inputVals[0]; - Value tmp = tmpVals[0]; + {[&](const KrnlBuilder &kb, Value inputVal, Value tmpVal, int64_t VL) { Type type = VL > 1 ? vecType : elementType; - Value accumulatedVec = emitScalarOpFor( - rewriter, create.getLoc(), op, type, {tmp, input}); - resultVals.emplace_back(accumulatedVec); - }, + return emitScalarOpFor( + rewriter, create.getLoc(), op, type, {tmpVal, inputVal}); + }}, /* post processing */ - [&](const KrnlBuilder &kb, ArrayRef tmpVals, - llvm::SmallVectorImpl &scalarOutputs, int64_t VL) { - Value tmp = tmpVals[0]; + {[&](const KrnlBuilder &kb, Value tmpVal, int64_t VL) { + // Horizontal reduction. Value accumulatedVal = - create.vec.reduction(getCombiningKind(), tmp); - // other operation... + create.vec.reduction(getCombiningKind(), tmpVal); + // Other post reduction operation... 
if (divideByMean()) { accumulatedVal = create.math.div(accumulatedVal, divisorForMean); } - scalarOutputs.emplace_back(accumulatedVal); - }); + return accumulatedVal; + }}); } // We assume here that the hardware has an efficient SIMD horizontal @@ -1101,7 +1106,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { // reductions that needs to be performed. void genHorizontalSimdReduction(ConversionPatternRewriter &rewriter, MDBuilder &create, Operation *op, Type elementType, Value input, - Value alloc, int64_t inRank, int64_t outRank, int64_t VL, + Value alloc, int64_t inRank, int64_t outRank, int64_t VL, bool simdOnly, int64_t collapsedInnermostLoops, bool isKeepDims, Value divisorForMean, bool enableParallel) const { LLVM_DEBUG(llvm::dbgs() << "gen horizontal simd reduction\n"); @@ -1157,7 +1162,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { Value initVec = create.vec.splat(vecType, identity); genOneHorizontalSimdReduction(rewriter, create, op, elementType, vecType, tmpAlloca, flatInput, flatAlloc, initVec, divisorForMean, - outLoopInd, simdUB, VL); + outLoopInd, simdUB, VL, simdOnly); }); } @@ -1183,66 +1188,49 @@ struct ONNXReductionOpLowering : public OpConversionPattern { MDBuilder &create, Operation *op, Type elementType, VectorType vecType, Value tmpBlockedAlloca, Value flatInput, Value flatAlloc, Value initVec, Value divisorForMean, ValueRange blockedOutLoopInd, - IndexExpr blockedCurrIndex, Value simdUB, int64_t VL) const { - // Init temp memory to init values. - Value zero = create.math.constantIndex(0); - for (int64_t i = 0; i < VL; ++i) { - create.vec.store( - initVec, tmpBlockedAlloca, {create.math.constantIndex(i), zero}); - } - // First step: blocked simd loop. - ValueRange simdLoopDef = create.krnl.defineLoops(1); - ValueRange blockedSimdLoopDef = create.krnl.block(simdLoopDef[0], VL); - create.krnl.iterate(simdLoopDef, {blockedSimdLoopDef[0]}, {zero}, {simdUB}, - [&](KrnlBuilder &ck, ValueRange simdLoopInd) { - MDBuilder create(ck); - // Loop over blocked output loop, block guaranteed to be full. - for (int64_t i = 0; i < VL; ++i) { - IndexExpr offset = LitIE(i); - IndexExpr blockLocalIndIE = blockedCurrIndex + offset; - Value blockLocalInd = blockLocalIndIE.getValue(); - // All of the non-blocked loop, plus the inter tile index of the - // blocked loop, and the blocked simd loop. - SmallVector inAccessVals = - firstFew(blockedOutLoopInd, -2); - inAccessVals.emplace_back(blockLocalInd); - inAccessVals.emplace_back(simdLoopInd[0]); - Value inputVec = create.vec.load(vecType, flatInput, inAccessVals); - // The tmpInd value is between 0 and VL-1, and is local index - - // blocked index. - Value tmpInd = offset.getValue(); - Value tmpVec = - create.vec.load(vecType, tmpBlockedAlloca, {tmpInd, zero}); - // Sum into redVec - Value accumulatedVec = emitScalarOpFor( - rewriter, create.getLoc(), op, vecType, {tmpVec, inputVec}); - create.vec.store(accumulatedVec, tmpBlockedAlloca, {tmpInd, zero}); - } /* intra block output loop */ - }); /* blocked simd loop */ - // Step 2 - // Load all temp vectors. - SmallVector redIn, redOut; - for (int64_t i = 0; i < VL; ++i) { - Value val = create.vec.load( - vecType, tmpBlockedAlloca, {create.math.constantIndex(i), zero}); - redIn.emplace_back(val); - } - // Reduce all of the temp vectors at once. 
- auto redFct = [&](Value a, Value b) -> Value { - return emitScalarOpFor( - rewriter, create.getLoc(), op, vecType, {a, b}); - }; - create.vec.multiReduction(redIn, redFct, redOut); - // The redOut list should have one value with SIMD of VL. - assert(redOut.size() == 1 && "expected only one val"); - Value accumulatedVal = redOut[0]; - // Perform the mean computation if required. - if (divideByMean()) { - Value divisorForMeanVec = create.vec.splat(vecType, divisorForMean); - accumulatedVal = create.math.div(accumulatedVal, divisorForMeanVec); + IndexExpr blockedCurrIndex, Value simdUB, int64_t VL, + bool simdOnly) const { + IndexExpr zero = LitIE(0); + IndexExpr lb = zero; + IndexExpr ub = SymIE(simdUB); + int64_t rank = blockedOutLoopInd.size(); + DimsExpr inputAF = SymListIE(blockedOutLoopInd); + inputAF[rank - 1] = blockedCurrIndex; + inputAF.emplace_back(zero); + DimsExpr tmpAF = {zero, zero}; + DimsExpr outputAF = SymListIE(blockedOutLoopInd); + Value identity = getIdentityValue( + rewriter, create.getLoc(), elementType); + if (simdOnly) { + create.affine.simdReduce2DIE( + lb, ub, VL, simdOnly, flatInput, inputAF, tmpBlockedAlloca, tmpAF, + flatAlloc, outputAF, identity, + [&](const AffineBuilder &b, Value inputVal, Value tmpVal, + int64_t VL) { + Type type = VL > 1 ? vecType : elementType; + return emitScalarOpFor( + rewriter, b.getLoc(), op, type, {tmpVal, inputVal}); + }, + [&](const AffineBuilder &b, Value tmpVal, int VL) { + if (divideByMean()) + return create.math.div(tmpVal, divisorForMean); + return tmpVal; + }); + } else { + create.scf.simdReduce2DIE( // Affine fails with dynamic shapes. + lb, ub, VL, simdOnly, flatInput, inputAF, tmpBlockedAlloca, tmpAF, + flatAlloc, outputAF, identity, + [&](const SCFBuilder &b, Value inputVal, Value tmpVal, int64_t VL) { + Type type = VL > 1 ? vecType : elementType; + return emitScalarOpFor( + rewriter, b.getLoc(), op, type, {tmpVal, inputVal}); + }, + [&](const SCFBuilder &b, Value tmpVal, int VL) { + if (divideByMean()) + return create.math.div(tmpVal, divisorForMean); + return tmpVal; + }); } - // Store final values. - create.vec.store(accumulatedVal, flatAlloc, blockedOutLoopInd); } // Solution when there is no horizontal SIMD op support and that shuffle ops @@ -1257,7 +1245,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { void genShuffleHorizontalSimdReduction(ConversionPatternRewriter &rewriter, MDBuilder &create, Operation *op, Type elementType, Value input, - Value alloc, int64_t inRank, int64_t outRank, int64_t VL, + Value alloc, int64_t inRank, int64_t outRank, int64_t VL, bool simdOnly, int64_t collapsedInnermostLoops, bool isKeepDims, Value divisorForMean, bool enableParallel) const { @@ -1287,8 +1275,11 @@ struct ONNXReductionOpLowering : public OpConversionPattern { assert(flatOutRank == flatInRank - 1 && "wrong assumptions about dims"); // Parallelism only if output is not a scalar. - if (flatOutRank == 0) + if (flatOutRank == 0 && enableParallel) { enableParallel = false; + onnxToKrnlParallelReport( + op, false, -1, 0, "zero flat out rank for reduction shuffle h-simd"); + } // Compute type of small temp vector. 
MemRefType tmpBlockedType = MemRefType::get({VL, VL}, elementType); @@ -1304,11 +1295,11 @@ struct ONNXReductionOpLowering : public OpConversionPattern { SmallVector lbs(flatOutRank, LitIE(0)); if (enableParallel) { int64_t parId; - if (findSuitableParallelDimension(lbs, flatOutDims, 0, 1, parId, - /*min iter for going parallel*/ 64 * VL)) { - create.krnl.parallel(optimizedOutLoopDef[0]); - onnxToKrnlParallelReport( - op, true, 0, lbs[0], flatOutDims[0], "reduction shuffle h-simd"); + if (findSuitableParallelDimension(lbs, flatOutDims, 0, flatOutRank, parId, + /*min iter for going parallel*/ 8 * VL)) { + create.krnl.parallel(optimizedOutLoopDef[parId]); + onnxToKrnlParallelReport(op, true, parId, lbs[parId], + flatOutDims[parId], "reduction shuffle h-simd"); } else { onnxToKrnlParallelReport(op, false, 0, lbs[0], flatOutDims[0], "not enough work for reduction shuffle h-simd"); @@ -1350,7 +1341,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { genOneHorizontalSimdReduction(rewriter, create, op, elementType, vecType, tmpBlockedAlloca, flatInput, flatAlloc, initVec, divisorForMean, outLoopInd, - simdUB, VL); + simdUB, VL, simdOnly); }); /* for inside blocked loop */ }, [&](SCFBuilder &scf) { @@ -1359,7 +1350,7 @@ struct ONNXReductionOpLowering : public OpConversionPattern { genVlHorizontalSimdReduction(rewriter, create, op, elementType, vecType, tmpBlockedAlloca, flatInput, flatAlloc, initVec, divisorForMean, blockedOutLoopInd, blockedCurrIndex, simdUB, - VL); + VL, simdOnly); }); }); /* blocked out loop */ } diff --git a/src/Conversion/ONNXToKrnl/NN/Normalization.cpp b/src/Conversion/ONNXToKrnl/NN/Normalization.cpp index f71741bd93..1eed53603a 100644 --- a/src/Conversion/ONNXToKrnl/NN/Normalization.cpp +++ b/src/Conversion/ONNXToKrnl/NN/Normalization.cpp @@ -255,7 +255,7 @@ struct ONNXInstanceNormalizationOpLowering for (int d = 0; d < rank - 2; ++d) inputAccessFct.emplace_back(spatial_loopInd[d]); // tmp += input[n,c, spatial dims] - Value oldSum = create.krnl.load(tmpMemRef, {}); + Value oldSum = create.krnl.load(tmpMemRef); Value val = create.krnl.load(inputMemRef, inputAccessFct); val = create.math.sub(val, mean); val = create.math.mul(val, val); diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 2567c4a1f4..775ee0cc35 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -70,8 +70,7 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, outputAF.emplace_back(zero); create.krnl.simdIterateIE(simdLb, simdUb, totVL, simdOnly, enableParallel, {flatInput}, {inputAF}, {flatAlloc}, {outputAF}, - [&](KrnlBuilder &kb, ArrayRef inputVals, - SmallVectorImpl &resVals, int64_t VL) { + {[&](const KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { MultiDialectBuilder create(kb); Value x = inputVals[0]; // Scale @@ -87,8 +86,8 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, // Saturate Value saturateX = create.math.clip(adjustX, qMin, qMax); Value res = create.math.cast(quantizedElementType, saturateX); - resVals.emplace_back(res); - }); + return res; + }}); if (totVL > 1) onnxToKrnlSimdReport(op, /*successful*/ true, totVL, simdLoopStaticTripCount, "quantizationLinear whole tensor"); diff --git a/src/Dialect/Krnl/DialectBuilder.cpp b/src/Dialect/Krnl/DialectBuilder.cpp index b07d7b09d5..541c3669b0 100644 --- a/src/Dialect/Krnl/DialectBuilder.cpp +++ 
b/src/Dialect/Krnl/DialectBuilder.cpp @@ -222,16 +222,28 @@ KrnlIterateOp KrnlBuilder::iterateIE(ValueRange originalLoops, }); } +void KrnlBuilder::forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, + bool useParallel, KrnlLoopBodyFn builderFn) const { + ValueRange originalLoopDef = defineLoops(1); + llvm::SmallVector optLoopDef(1, originalLoopDef[0]); + if (step > 1) { + // Block loop by step. + ValueRange blockedLoopDef = block(originalLoopDef[0], step); + optLoopDef[0] = blockedLoopDef[0]; + } + if (useParallel) + parallel(optLoopDef[0]); + iterateIE(originalLoopDef, optLoopDef, {lb}, {ub}, builderFn); +} + void KrnlBuilder::simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, bool useParallel, ArrayRef inputs, ArrayRef inputAFs, ArrayRef outputs, ArrayRef outputAFs, - function_ref inputVals, - llvm::SmallVectorImpl &resultVals, int64_t VL)> - bodyBuilderFn) const { + ArrayRef iterateBodyFnList) const { onnx_mlir::impl::simdIterateIE(*this, lb, ub, VL, fullySimd, useParallel, inputs, inputAFs, outputs, outputAFs, - bodyBuilderFn); + iterateBodyFnList); } void KrnlBuilder::simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, @@ -239,20 +251,27 @@ void KrnlBuilder::simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, ArrayRef tmps, ArrayRef tmpAFs, ArrayRef outputs, ArrayRef outputAFs, ArrayRef initVals, /* reduction function (simd or scalar) */ - function_ref inputVals, - ArrayRef tmpVals, llvm::SmallVectorImpl &resultVals, - int64_t VL)> - reductionBuilderFn, + ArrayRef reductionBodyFnList, /* post reduction function (simd to scalar + post processing)*/ - function_ref tmpVals, - llvm::SmallVectorImpl &scalarOutputs, int64_t VL)> - postProcessingBuilderFn) const { + ArrayRef postReductionBodyFnList) const { onnx_mlir::impl::simdReduceIE(*this, lb, ub, VL, fullySimd, inputs, inputAFs, tmps, tmpAFs, outputs, outputAFs, initVals, - reductionBuilderFn, postProcessingBuilderFn); + reductionBodyFnList, postReductionBodyFnList); +} + +void KrnlBuilder::simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, + bool fullySimd, Value input, DimsExpr inputAF, Value tmp, DimsExpr tmpAF, + Value output, DimsExpr outputAF, Value initVal, + /* reduction functions (simd or scalar) */ + KrnlSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + KrnlSimdPostReductionBodyFn postReductionBodyFn) const { + onnx_mlir::impl::simdReduce2DIE(*this, lb, ub, VL, + fullySimd, input, inputAF, tmp, tmpAF, output, outputAF, initVal, + reductionBodyFn, postReductionBodyFn); } -void KrnlBuilder::yield(mlir::ValueRange iterArgs) const { +void KrnlBuilder::yield(ValueRange iterArgs) const { b().create(loc(), iterArgs); } diff --git a/src/Dialect/Krnl/DialectBuilder.hpp b/src/Dialect/Krnl/DialectBuilder.hpp index dd754b05dc..6a8d23097c 100644 --- a/src/Dialect/Krnl/DialectBuilder.hpp +++ b/src/Dialect/Krnl/DialectBuilder.hpp @@ -30,12 +30,12 @@ struct KrnlBuilder : public DialectBuilder { KrnlBuilder(const DialectBuilder &db) : DialectBuilder(db) {} virtual ~KrnlBuilder() {} + // Common load/store interface (krnl/affine/memref) // Add offsets (if any) to the least significant dims. mlir::Value load(mlir::Value memref, mlir::ValueRange indices = {}, mlir::ValueRange offsets = {}) const; mlir::Value loadIE(mlir::Value memref, mlir::ArrayRef indices = {}, mlir::ValueRange offsets = {}) const; - // Add offsets (if any) to the least significant dims. 
void store(mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}, mlir::ValueRange offsets = {}) const; void storeIE(mlir::Value val, mlir::Value memref, @@ -70,11 +70,12 @@ struct KrnlBuilder : public DialectBuilder { // Iterate over optimized loops given the original loops, lbs and ubs. Lambda // function implement the body of the loop, and receive a KRNL builder and the // loop indices. + using KrnlLoopBodyFn = + mlir::function_ref; + void iterate(mlir::ValueRange originalLoops, mlir::ValueRange optimizedLoops, mlir::ValueRange lbs, mlir::ValueRange ubs, - mlir::function_ref - bodyBuilderFn) const; + KrnlLoopBodyFn bodyBuilderFn) const; mlir::KrnlIterateOp iterate(mlir::ValueRange originalLoops, mlir::ValueRange optimizedLoops, mlir::ValueRange lbs, mlir::ValueRange ubs, mlir::ValueRange inits, @@ -87,10 +88,7 @@ struct KrnlBuilder : public DialectBuilder { // Same versions with Index Expressions for bounds. void iterateIE(mlir::ValueRange originalLoops, mlir::ValueRange optimizedLoops, mlir::ArrayRef lbs, - mlir::ArrayRef ubs, - mlir::function_ref - bodyBuilderFn) const; + mlir::ArrayRef ubs, KrnlLoopBodyFn bodyBuilderFn) const; mlir::KrnlIterateOp iterateIE(mlir::ValueRange originalLoops, mlir::ValueRange optimizedLoops, mlir::ArrayRef lbs, mlir::ArrayRef ubs, mlir::ValueRange inits, @@ -98,20 +96,30 @@ struct KrnlBuilder : public DialectBuilder { mlir::ValueRange blockIters)> bodyBuilderFn) const; + // Common loop interface (krnl/affine/scf). + void forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, bool useParallel, + KrnlLoopBodyFn builderFn) const; + + // Common simd loop interface (krnl/affine/scf). /* Iterate over a loop executing the loop body in SIMD mode (of vector length VL) from lb to ub. A scalar loop may execute up to VL-1 loop iterations when the trip count is not a multiple of VL. If fullySimd is true, then the call assumes that the trip count is a multiple of VL. - This call needs be given each of the memref inputs to the loop body, given - as an ordered pair memref value and its corresponding access function. Same - hold for all the memref outputs of the loop body. + This simdIterateIE needs to be given each of the memref inputs to the loop + body, given as an ordered pair memref value and its corresponding access + function. The same holds for all the memref outputs of the loop body. + + The loop body is constructed by calling each of the KrnlSimdIterateBodyFn + given in the list. Each function is responsible for returning one output + value. The returned values are eventually stored in the output memrefs at a + location given by its respective output access function. - The loop body is given a KRNL builder, a list of loaded input (same order - as the input's memrefs and access functions). It will generate values that - must be placed in the result list in the same order as the output's memrefs - and access functions. + To generate their output, each KrnlSimdIterateBodyFn function is given + a KRNL builder, a list of loaded inputs (same order + as the inputs' memrefs and access functions), and the current VectorLength + (VL). VL is either the original VL or 1 (when executing in scalar mode). It will be the responsibility of this call to load each of the inputs and store each of the outputs. When operating in SIMD mode, every input and @@ -129,45 +137,61 @@ struct KrnlBuilder : public DialectBuilder { Dialect/Mlir/DialectBuilder.hpp.inc.
*/ + using KrnlSimdIterateBodyFn = impl::SimdIterateBodyFn; void simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, bool useParallel, mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, - mlir::function_ref inputVals, - llvm::SmallVectorImpl &resultVals, int64_t VL)> - bodyBuilderFn) const; + mlir::ArrayRef bodyBuilderFnList) const; /* Works similarly as simdIterateIE, but performs a reduction to a single scalar per output value. Inputs must be strided in their innermost dimensions. Temps are used to hold the temporary results (partial results per SIMD lane), and the outputs have the scalar reduction outputs - Two functions are given: reductionBuilderFn to perform the partial - reductions into the temporary values tmps, finishing with up to VL partial - reductions - The second function: postProcessingBuilderFn performs the reductions of the - up to VL partial reductions into a final scalar reduction to be stored into - the outputs (a scalar value). For some reductions, post processing is also - needed, for example, mean reduction divide the accumulated sum by the - number of elements. That step is also performed here. + + Two function lists are given: a list of reductionBodyFn to perform the + partial reductions into the temporary values tmps, finishing with up to VL + partial reductions. The second list of postReductionBodyFn performs the + reductions of the up to VL partial reductions into a final scalar reduction + to be stored into the outputs (a scalar value). For some reductions, post + processing is also needed, for example, mean reduction divides the + accumulated sum by the number of elements. That step is also performed + here. */ + using KrnlSimdReductionBodyFn = impl::SimdReductionBodyFn; + using KrnlSimdPostReductionBodyFn = + impl::SimdPostReductionBodyFn; + void simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, mlir::ArrayRef tmps, mlir::ArrayRef tmpAFs, mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, mlir::ArrayRef initVals, /* reduction function (simd or scalar) */ - mlir::function_ref inputVals, - mlir::ArrayRef tmpVals, - llvm::SmallVectorImpl &resultVals, int64_t VL)> - reductionBuilderFn, + mlir::ArrayRef reductionBodyFnList, /* post reduction function (simd to scalar + post processing)*/ - mlir::function_ref tmpVals, - llvm::SmallVectorImpl &scalarOutputs, int64_t VL)> - postProcessingBuilderFn) const; + mlir::ArrayRef postReductionBodyFnList) + const; + + /* + Same as simdReduceIE, but performs VL reductions at once. It expects at least + VL iterations in the second to last dimension of inputs/outputs. + + Unlike simdReduceIE, the second function is for post processing only. In + simdReduceIE, that function was also used to reduce the SIMD temporary + reduction into a single scalar. + + Also, at this time, simdReduce2DIE processes only one reduction at a time, + whereas simdReduceIE could process an arbitrary number of reductions.
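+
+    A minimal usage sketch (for illustration only; create is assumed to be a
+    MultiDialectBuilder holding a KrnlBuilder, and flatInput, inputAF, tmp,
+    tmpAF, flatAlloc, outputAF, and zeroVal are hypothetical placeholders),
+    assuming a plain sum reduction where no post processing is needed:
+
+      create.krnl.simdReduce2DIE(lb, ub, VL, fullySimd,
+          flatInput, inputAF, tmp, tmpAF, flatAlloc, outputAF, zeroVal,
+          [&](const KrnlBuilder &b, mlir::Value inputVal, mlir::Value tmpVal,
+              int64_t VL) {
+            // Accumulate into tmp (SIMD when VL > 1, scalar otherwise).
+            MathBuilder createMath(b);
+            return createMath.add(tmpVal, inputVal);
+          },
+          [&](const KrnlBuilder &b, mlir::Value tmpVal, int64_t VL) {
+            // Plain sum: no post processing, return the reduced value as is.
+            return tmpVal;
+          });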
+ */ + void simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + mlir::Value input, DimsExpr inputAF, mlir::Value tmp, DimsExpr tmpAF, + mlir::Value output, DimsExpr outputAF, mlir::Value initVal, + /* reduction functions (simd or scalar) */ + KrnlSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + KrnlSimdPostReductionBodyFn postReductionBodyFn) const; void yield(mlir::ValueRange iterArgs) const; diff --git a/src/Dialect/Mlir/DialectBuilder.cpp b/src/Dialect/Mlir/DialectBuilder.cpp index c77dfb5368..75fa59cb96 100644 --- a/src/Dialect/Mlir/DialectBuilder.cpp +++ b/src/Dialect/Mlir/DialectBuilder.cpp @@ -1726,11 +1726,11 @@ void SCFBuilder::ifThenElse(Value cond, } } -void SCFBuilder::forLoop(Value lowerBound, Value upperBound, int64_t step, - function_ref bodyFn) const { +void SCFBuilder::forLoop( + Value lb, Value ub, int64_t step, SCFLoopBodyFn bodyFn) const { MathBuilder createMath(*this); Value stepVal = createMath.constantIndex(step); - b().create(loc(), lowerBound, upperBound, stepVal, std::nullopt, + b().create(loc(), lb, ub, stepVal, std::nullopt, [&](OpBuilder &childBuilder, Location childLoc, Value inductionVar, ValueRange args) { SCFBuilder builder(childBuilder, childLoc); @@ -1739,10 +1739,20 @@ void SCFBuilder::forLoop(Value lowerBound, Value upperBound, int64_t step, }); } -void SCFBuilder::parallelLoops(ValueRange lowerBounds, ValueRange upperBounds, - ValueRange steps, - function_ref bodyFn) const { - b().create(loc(), lowerBounds, upperBounds, steps, +void SCFBuilder::forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, + bool useParallel, SCFLoopBodyFn bodyFn) const { + if (useParallel) { + MathBuilder createMath(*this); + Value stepVal = createMath.constantIndex(step); + parallelLoops({lb.getValue()}, {ub.getValue()}, {stepVal}, bodyFn); + } else { + forLoop(lb.getValue(), ub.getValue(), step, bodyFn); + } +} + +void SCFBuilder::parallelLoops(ValueRange lbs, ValueRange ubs, ValueRange steps, + SCFLoopBodyFn bodyFn) const { + b().create(loc(), lbs, ubs, steps, [&](OpBuilder &childBuilder, Location childLoc, ValueRange inductionVars) { SCFBuilder builder(childBuilder, childLoc); @@ -1757,12 +1767,9 @@ void SCFBuilder::simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, bool useParallel, ArrayRef inputs, ArrayRef inputAFs, ArrayRef outputs, ArrayRef outputAFs, - function_ref inputVals, - llvm::SmallVectorImpl &resultVals, int64_t VL)> - bodyBuilderFn) const { + ArrayRef bodyFnList) const { onnx_mlir::impl::simdIterateIE(*this, lb, ub, VL, - fullySimd, useParallel, inputs, inputAFs, outputs, outputAFs, - bodyBuilderFn); + fullySimd, useParallel, inputs, inputAFs, outputs, outputAFs, bodyFnList); } void SCFBuilder::simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, @@ -1770,17 +1777,24 @@ void SCFBuilder::simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, ArrayRef tmps, ArrayRef tmpAFs, ArrayRef outputs, ArrayRef outputAFs, ArrayRef initVals, /* reduction function (simd or scalar) */ - function_ref inputVals, - ArrayRef tmpVals, llvm::SmallVectorImpl &resultVals, - int64_t VL)> - reductionBuilderFn, + mlir::ArrayRef reductionFnList, /* post reduction function (simd to scalar + post processing)*/ - function_ref tmpVals, - llvm::SmallVectorImpl &scalarOutputs, int64_t VL)> - postProcessingBuilderFn) const { + mlir::ArrayRef postReductionFnList) const { onnx_mlir::impl::simdReduceIE(*this, lb, ub, VL, fullySimd, inputs, inputAFs, tmps, tmpAFs, outputs, outputAFs, initVals, - reductionBuilderFn, 
postProcessingBuilderFn); + reductionFnList, postReductionFnList); +} + +void SCFBuilder::simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, + bool fullySimd, Value input, DimsExpr inputAF, Value tmp, DimsExpr tmpAF, + Value output, DimsExpr outputAF, Value initVal, + /* reduction functions (simd or scalar) */ + SCFSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + SCFSimdPostReductionBodyFn postReductionBodyFn) const { + onnx_mlir::impl::simdReduce2DIE(*this, lb, ub, VL, + fullySimd, input, inputAF, tmp, tmpAF, output, outputAF, initVal, + reductionBodyFn, postReductionBodyFn); } //===----------------------------------------------------------------------===// diff --git a/src/Dialect/Mlir/DialectBuilder.hpp b/src/Dialect/Mlir/DialectBuilder.hpp index 4ee8f62863..f1c65bb32c 100644 --- a/src/Dialect/Mlir/DialectBuilder.hpp +++ b/src/Dialect/Mlir/DialectBuilder.hpp @@ -254,13 +254,12 @@ struct MemRefBuilder final : DialectBuilder { // Constants static const int64_t defaultAlign; + // Common load/store interface (krnl/affine/memref) // Add offsets (if any) to the least significant memref dims. mlir::Value load(mlir::Value memref, mlir::ValueRange indices = {}, mlir::ValueRange offsets = {}) const; mlir::Value loadIE(mlir::Value memref, mlir::ArrayRef indices = {}, mlir::ValueRange offsets = {}) const; - - // Add offsets (if any) to the least significant memref dims. void store(mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}, mlir::ValueRange offsets = {}) const; void storeIE(mlir::Value val, mlir::Value memref, @@ -427,6 +426,34 @@ struct MemRefBuilder final : DialectBuilder { const; }; +//===----------------------------------------------------------------------===// +// Functions definitions for SIMD methods (simdIterate & simdReduce) +//===----------------------------------------------------------------------===// + +namespace impl { + +// For simdIterate: given a list of inputs, create one output value. +template +using SimdIterateBodyFn = std::function inputVals, int64_t VL)>; + +// For simdReduce: take one input & one temp reduction value, and generate the +// new reduction value. +template +using SimdReductionBodyFn = std::function; + +// For simdReduce: take one temp simd reduction value, create a scalar +// reduction, and possibly apply post processing to it (e.g. div by number of +// elements). +// +// For simdReduce2D: only the post processing. Reduction is done before. +template +using SimdPostReductionBodyFn = std::function; + +} // namespace impl + //===----------------------------------------------------------------------===// // Structured Control Flow (SCF) Builder //===----------------------------------------------------------------------===// @@ -442,42 +469,48 @@ struct SCFBuilder final : DialectBuilder { void ifThenElse(mlir::Value cond, mlir::function_ref thenFn, mlir::function_ref elseFn = nullptr) const; - // Create a for loop. - void forLoop(mlir::Value lowerBound, mlir::Value upperBound, int64_t step, - mlir::function_ref bodyFn) const; - // Create a parallel for loop. - void parallelLoops(mlir::ValueRange lowerBounds, mlir::ValueRange upperBounds, - mlir::ValueRange steps, + // Common loop interface (krnl/affine/scf). 
+ using SCFLoopBodyFn = + mlir::function_ref; + void forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, bool useParallel, mlir::function_ref bodyFn) const; + // Custom interface + void forLoop( + mlir::Value lb, mlir::Value ub, int64_t step, SCFLoopBodyFn bodyFn) const; + void parallelLoops(mlir::ValueRange lbs, mlir::ValueRange ubs, + mlir::ValueRange steps, SCFLoopBodyFn bodyFn) const; + void yield() const; + // Common simd loop interface (krnl/affine/scf). // For detailed description, see KrnlBuilder.hpp file. + using SCFSimdIterateBodyFn = impl::SimdIterateBodyFn; void simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, bool useParallel, mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, - mlir::function_ref inputVals, - llvm::SmallVectorImpl &resultVals, int64_t VL)> - bodyBuilderFn) const; + mlir::ArrayRef simdIterateBodyList) const; // For detailed description, see KrnlBuilder.hpp file. + using SCFSimdReductionBodyFn = impl::SimdReductionBodyFn; + using SCFSimdPostReductionBodyFn = impl::SimdPostReductionBodyFn; void simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, mlir::ArrayRef temps, mlir::ArrayRef tempAFs, mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, mlir::ArrayRef initVals, /* reduction function (simd or scalar) */ - mlir::function_ref inputVals, - mlir::ArrayRef tmpVals, - llvm::SmallVectorImpl &resultVals, int64_t VL)> - reductionBuilderFn, + mlir::ArrayRef simdReductionBodyFnList, /* post reduction function (simd to scalar + post processing)*/ - mlir::function_ref tmpVals, - llvm::SmallVectorImpl &scalarOutputs, int64_t VL)> - postProcessingBuilderFn) const; + mlir::ArrayRef simdPostReductionBodyFnList) + const; + void simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + mlir::Value input, DimsExpr inputAF, mlir::Value tmp, DimsExpr tmpAF, + mlir::Value output, DimsExpr outputAF, mlir::Value initVal, + /* reduction functions (simd or scalar) */ + SCFSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + SCFSimdPostReductionBodyFn postReductionBodyFn) const; }; //===----------------------------------------------------------------------===// @@ -558,13 +591,12 @@ struct GenericAffineBuilder final : DialectBuilder { GenericAffineBuilder(const DialectBuilder &db) : DialectBuilder(db) {} virtual ~GenericAffineBuilder() {} + // Common load/store interface (krnl/affine/memref) // Add offsets (if any) to the least significant memref dims. mlir::Value load(mlir::Value memref, mlir::ValueRange indices = {}, mlir::ValueRange offsets = {}) const; mlir::Value loadIE(mlir::Value memref, mlir::ArrayRef indices = {}, mlir::ValueRange offsets = {}) const; - - // Add offsets (if any) to the least significant memref dims. void store(mlir::Value val, mlir::Value memref, mlir::ValueRange indices = {}, mlir::ValueRange offsets = {}) const; void storeIE(mlir::Value val, mlir::Value memref, @@ -575,13 +607,48 @@ struct GenericAffineBuilder final : DialectBuilder { mlir::ValueRange indices, bool isWrite, unsigned localityHint, bool isDataCache = true); + // Common loop interface (krnl/affine/scf). 
+ using GenericAffineLoopBodyFn = + mlir::function_ref; + void forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, bool useParallel, + GenericAffineLoopBodyFn builderFn) const; + + // Custom interface void forLoopIE(IndexExpr lb, IndexExpr ub, int64_t step, - mlir::function_ref - builderFn) const; + GenericAffineLoopBodyFn builderFn) const; // Sequential only. void forLoopsIE(mlir::ArrayRef lbs, mlir::ArrayRef ubs, - mlir::ArrayRef steps, - mlir::function_ref - builderFn) const; + mlir::ArrayRef steps, GenericAffineLoopBodyFn builderFn) const; + + // Common simd loop interface (krnl/affine/scf). + using GenericAffineSimdIterateBodyFn = + impl::SimdIterateBodyFn>; + void simdIterateIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + bool useParallel, mlir::ArrayRef inputs, + mlir::ArrayRef inputAFs, mlir::ArrayRef outputs, + mlir::ArrayRef outputAFs, + mlir::ArrayRef simdIterateBodyList) const; + + using GenericAffineSimdReductionBodyFn = + impl::SimdReductionBodyFn>; + using GenericAffineSimdPostReductionBodyFn = + impl::SimdPostReductionBodyFn>; + void simdReduceIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, + mlir::ArrayRef temps, mlir::ArrayRef tempAFs, + mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, + mlir::ArrayRef initVals, + /* reduction function (simd or scalar) */ + mlir::ArrayRef simdReductionBodyFnList, + /* post reduction function (simd to scalar + post processing)*/ + mlir::ArrayRef + simdPostReductionBodyFnList) const; + void simdReduce2DIE(IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, + mlir::Value input, DimsExpr inputAF, mlir::Value tmp, DimsExpr tmpAF, + mlir::Value output, DimsExpr outputAF, mlir::Value initVal, + /* reduction functions (simd or scalar) */ + GenericAffineSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + GenericAffineSimdPostReductionBodyFn postReductionBodyFn) const; // This if then else construct has no arguments to the blocks. void ifThenElseIE(IndexExprScope &scope, mlir::ArrayRef conditions, @@ -595,7 +662,7 @@ struct GenericAffineBuilder final : DialectBuilder { void yield() const; private: - // Support for multiple forLoopIE loops. + // Support for multiple for loops. void recursionForLoopsIE(mlir::ArrayRef lbs, mlir::ArrayRef ubs, mlir::ArrayRef steps, llvm::SmallVectorImpl &loopIndices, diff --git a/src/Dialect/Mlir/DialectBuilder.hpp.inc b/src/Dialect/Mlir/DialectBuilder.hpp.inc index fc1f17a981..6e84be9931 100644 --- a/src/Dialect/Mlir/DialectBuilder.hpp.inc +++ b/src/Dialect/Mlir/DialectBuilder.hpp.inc @@ -25,7 +25,7 @@ // Templates for load / store //===----------------------------------------------------------------------===// -namespace impl { +namespace impl { // Hide support for loads / stores in impl namespace. 
template mlir::Value load(const BUILDER &b, mlir::Value memref, mlir::ValueRange indices, @@ -83,176 +83,8 @@ void storeIE(const BUILDER &b, mlir::Value val, mlir::Value memref, IndexExpr::getValues(indices, indexValues); store(b, val, memref, indexValues, offsets); } -} // namespace impl - -//===----------------------------------------------------------------------===// -// Templates for GenericAffineBuilder -//===----------------------------------------------------------------------===// - -template -mlir::Value GenericAffineBuilder::load(mlir::Value memref, - mlir::ValueRange indices, mlir::ValueRange offsets) const { - return onnx_mlir::impl::load( - *this, memref, indices, offsets); -} - -template -mlir::Value GenericAffineBuilder::loadIE(mlir::Value memref, - mlir::ArrayRef indices, mlir::ValueRange offsets) const { - return onnx_mlir::impl::loadIE( - *this, memref, indices, offsets); -} - -template -inline void GenericAffineBuilder::store(mlir::Value val, - mlir::Value memref, mlir::ValueRange indices, - mlir::ValueRange offsets) const { - onnx_mlir::impl::store( - *this, val, memref, indices, offsets); -} - -template -inline void GenericAffineBuilder::storeIE(mlir::Value val, - mlir::Value memref, mlir::ArrayRef indices, - mlir::ValueRange offsets) const { - onnx_mlir::impl::storeIE( - *this, val, memref, indices, offsets); -} - -template -inline mlir::Operation *GenericAffineBuilder::prefetch( - mlir::Value memref, mlir::AffineMap map, mlir::ValueRange indices, - bool isWrite, unsigned localityHint, bool isDataCache) { - llvm::SmallVector indexArray(indices); - return b().template create( - loc(), memref, map, indexArray, isWrite, localityHint, isDataCache); -} - -template -inline void GenericAffineBuilder::forLoopIE(IndexExpr lb, - IndexExpr ub, int64_t step, - mlir::function_ref - builderFn) const { - // Transform IndexExpressions into value maps and list of operands. - mlir::AffineMap lbMap, ubMap; - llvm::SmallVector lbOperands, ubOperands; - lb.getAffineMapAndOperands(lbMap, lbOperands); - ub.getAffineMapAndOperands(ubMap, ubOperands); - // Create affine for. - b().template create(loc(), lbOperands, lbMap, - ubOperands, ubMap, step, mlir::ValueRange{}, - [&](mlir::OpBuilder &b, mlir::Location loc, mlir::Value index, - mlir::ValueRange args) { - GenericAffineBuilder createAffine(b, loc); - builderFn(createAffine, {index}); - createAffine.yield(); - }); -} - -template -inline void GenericAffineBuilder::forLoopsIE( - mlir::ArrayRef lbs, mlir::ArrayRef ubs, - mlir::ArrayRef steps, - mlir::function_ref - builderFn) const { - assert(lbs.size() == ubs.size() && "expected identical sizes"); - assert(lbs.size() == steps.size() && "expected identical sizes"); - llvm::SmallVector loopIndices; - recursionForLoopsIE(lbs, ubs, steps, loopIndices, builderFn); -} - -// This if then else construct has no arguments to the blocks. -template -inline void GenericAffineBuilder::ifThenElseIE( - IndexExprScope &scope, mlir::ArrayRef conditions, - mlir::function_ref thenFn, - mlir::function_ref elseFn) const { - int64_t rank = conditions.size(); - llvm::SmallVector affineCond; - bool allTrue = true; - bool allFalse = true; - for (IndexExpr c : conditions) { - assert(c.isAffine() && "conditions expected to be affine"); - affineCond.emplace_back(c.getAffineExpr()); - if (c.isLiteral()) { - if (c.getLiteral() < 0) // Inequality is expr >= 0, test if false. - allTrue = false; - if (c.getLiteral() >= 0) // Inequality is expr >= 0, test if true. 
- allFalse = false; - } else { - allTrue = allFalse = false; - } - } - llvm::SmallVector isEq(rank, false); - auto inset = mlir::IntegerSet::get( - scope.getNumDims(), scope.getNumSymbols(), affineCond, isEq); - llvm::SmallVector dimAndSymbolList; - scope.getDimAndSymbolList(dimAndSymbolList); - auto ifOp = b().template create( - loc(), inset, dimAndSymbolList, true); - mlir::Block *thenBlock = ifOp.getThenBlock(); - mlir::Block *elseBlock = ifOp.getElseBlock(); - if (!allFalse) { - appendToBlock(thenBlock, [&](mlir::ValueRange args) { - GenericAffineBuilder createAffine(b(), loc()); - thenFn(createAffine); - }); - } - if (!allTrue) { - appendToBlock(elseBlock, [&](mlir::ValueRange args) { - GenericAffineBuilder createAffine(b(), loc()); - elseFn(createAffine); - }); - } -} - -template -mlir::Value GenericAffineBuilder::apply( - mlir::AffineMap map, mlir::ValueRange operands) const { - return b().template create(loc(), map, operands); -} -template -inline void GenericAffineBuilder::yield() const { - b().template create(loc()); -} - -// Support for multiple forLoopIE loops. -template -void GenericAffineBuilder::recursionForLoopsIE( - mlir::ArrayRef lbs, mlir::ArrayRef ubs, - mlir::ArrayRef steps, - llvm::SmallVectorImpl &loopIndices, - mlir::function_ref - builderFn) const { - int d = loopIndices.size(); - if (d < (int)lbs.size()) { - // Issue a loop and recurse again. - forLoopIE(lbs[d], ubs[d], steps[d], - [&](GenericAffineBuilder &createAffine, mlir::ValueRange loopInd) { - loopIndices.emplace_back(loopInd[0]); - recursionForLoopsIE(lbs, ubs, steps, loopIndices, builderFn); - }); - } else { - // Call lambda function - GenericAffineBuilder createAffine(b(), loc()); - builderFn(createAffine, loopIndices); - } -} - -// Support for adding blocks. -template -inline void GenericAffineBuilder::appendToBlock( - mlir::Block *block, - mlir::function_ref builderFn) const { - mlir::OpBuilder::InsertionGuard guard(b()); - if (block->empty() || - !block->back().mightHaveTrait()) { - b().setInsertionPointToEnd(block); - } else - b().setInsertionPoint(&block->back()); - builderFn(block->getArguments()); -} +} // namespace impl //===----------------------------------------------------------------------===// // Templates for SIMD code gen (instantiated for KRNL and SCF builders) @@ -297,7 +129,7 @@ krnl.iterate(loop i from 0 to 256) { fullySimd=true, useParallel=false, // loop options inputs={A, B}, inputAFs={aAF, bAF}, // inputs outputs={R}, outputAFs={rAF}, // outputs - krnl) // lambda function for kernel + {krnl}) // lambda function for kernel 3) Krnl for SIMD loop @@ -307,14 +139,13 @@ krnl.iterate(loop i from 0 to 256) { c) list of results values, that must be enqueued by the kernel d) totVL used for the loop (VL for simd, 1 for scalar) - The same kernel will be used in a SIMD context, in which the inputs and + The same kernels will be used in a SIMD context, in which the inputs and outputs must be vectors of VL elements, or in a scalar context, in which the inputs and outputs must be scalars. 
In our example, the kernel is as follows - [&](KrnlBuilder &kb, ArrayRef inputVals, - SmallVectorImpl &resVals, int64_t VL) { + [&](KrnlBuilder &kb, ArrayRef inputVals, int64_t VL) { MultiDialectBuilder create(kb); Value aVal = inputVals[0]; // simd or scalar Value bVal = inputVals[1]; // simd or scalar @@ -322,7 +153,7 @@ krnl.iterate(loop i from 0 to 256) { Value newVal = create.math.add(aVal, bVal); // simd or scalar newVal = create.math.add(newVal, cVal); // if newVal is simd, cVal is // splatted - res.emplace_back(newVal); // Save simd or scalar result. + return newVal; // Save simd or scalar result. } The krnl.simdIterateIE will be in charge of loading and saving the values in @@ -334,18 +165,20 @@ krnl.iterate(loop i from 0 to 256) { (either totVL>1 or 1). */ +// Definition of SimdIterateBodyFn, see Mlir/DialectBuilder.hpp + template void simdIterateIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, bool useParallel, mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, - mlir::function_ref inputVals, - llvm::SmallVectorImpl &resultVals, int64_t VL)> - bodyBuilderFn) { + mlir::ArrayRef> iterateBodyList) { int64_t inputNum = inputs.size(); assert(inputAFs.size() == inputs.size() && "expected same size"); int64_t outputNum = outputs.size(); assert(outputAFs.size() == outputs.size() && "expected same size"); + int64_t fnNum = iterateBodyList.size(); + assert((int64_t)fnNum == outputNum && "expect 1 loop function per output"); if (VL > 1) { // Want SIMD, execute full SIMD loops blocked by VL. @@ -379,21 +212,19 @@ void simdIterateIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, if (MemRefBuilder::hasOneElementInInnermostDims(input, 1)) { // Has a reference with a scalar innermost dim, just load as a // scalar. No need to add the induction variable. - mlir::Value scalarVal = createMem.loadIE(input, AF); - vecInputVals.emplace_back(scalarVal); + vecInputVals.emplace_back(createMem.loadIE(input, AF)); } else { // Have a vector. auto vecType = mlir::VectorType::get({VL}, type.getElementType()); AF[rank - 1] = AF[rank - 1] + ind; // Add induction var. - mlir::Value vecVal = createVec.loadIE(vecType, input, AF); - vecInputVals.emplace_back(vecVal); + vecInputVals.emplace_back(createVec.loadIE(vecType, input, AF)); } } // Call the method to compute the values. llvm::SmallVector vecResVals; - bodyBuilderFn(b, vecInputVals, vecResVals, VL); - assert((int64_t)vecResVals.size() == outputNum && - "loop body with incorrect number of results"); + for (int64_t f = 0; f < outputNum; ++f) { + vecResVals.emplace_back(iterateBodyList[f](b, vecInputVals, VL)); + } // Store all the outputs as vectors of VL values, for (int64_t i = 0; i < outputNum; ++i) { auto type = mlir::cast(outputs[i].getType()); @@ -405,26 +236,13 @@ void simdIterateIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, } }; - // Invocation of the (possibly parallel) SIMD loop. 
- if constexpr (std::is_same::value) { - // Use KRNL interface - mlir::ValueRange loopDef = builder.defineLoops(1); - mlir::ValueRange blockedLoopDef = builder.block(loopDef[0], VL); - if (useParallel) - builder.parallel({blockedLoopDef[0]}); - builder.iterateIE( - loopDef, {blockedLoopDef[0]}, {lb}, {simdUb}, simdLoopBody); - } else if constexpr (std::is_same::value) { - if (useParallel) { - IndexExpr litVL = LitIE(VL); - builder.parallelLoops({lb.getValue()}, {simdUb.getValue()}, - {litVL.getValue()}, simdLoopBody); - } else { - builder.forLoop(lb.getValue(), simdUb.getValue(), VL, simdLoopBody); - } - } else { + // Invocation of the (possibly parallel) SIMD loop. + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) + builder.forLoopIE(lb, simdUb, VL, useParallel, simdLoopBody); + else llvm_unreachable("BUILDER type not supported\n"); - } if (fullySimd) // Asserted that we only have SIMD iterations, we are done. @@ -462,19 +280,17 @@ void simdIterateIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, if (MemRefBuilder::hasOneElementInInnermostDims(input, 1)) { // Has a reference with a scalar innermost dim, just load as a // scalar. No need to add the induction variable. - mlir::Value scalarVal = createMem.loadIE(input, AF); - scalarInputVals.emplace_back(scalarVal); + scalarInputVals.emplace_back(createMem.loadIE(input, AF)); } else { AF[rank - 1] = AF[rank - 1] + ind; - mlir::Value scalarVal = createMem.loadIE(input, AF); - scalarInputVals.emplace_back(scalarVal); + scalarInputVals.emplace_back(createMem.loadIE(input, AF)); } } // Call the method to compute the values. llvm::SmallVector scalarResVals; - bodyBuilderFn(b, scalarInputVals, scalarResVals, /*VL*/ 1); - assert((int64_t)scalarResVals.size() == outputNum && - "loop body with incorrect number of results"); + for (int64_t f = 0; f < outputNum; ++f) { + scalarResVals.emplace_back(iterateBodyList[f](b, scalarInputVals, 1)); + } // Store all the outputs as vectors of VL values, for (int64_t i = 0; i < outputNum; ++i) { auto type = mlir::cast(outputs[i].getType()); @@ -487,55 +303,72 @@ void simdIterateIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, }; // Invocation of the scalar loop. - if constexpr (std::is_same::value) { - // Use KRNL dialect. - mlir::ValueRange loopDef = builder.defineLoops(1); - builder.iterateIE(loopDef, loopDef, {lb}, {ub}, scalarLoopBody); - } else if constexpr (std::is_same::value) { - builder.forLoop(lb.getValue(), ub.getValue(), 1, scalarLoopBody); - } else { + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) + builder.forLoopIE(lb, ub, 1, false /*parallel*/, scalarLoopBody); + else llvm_unreachable("BUILDER type not supported\n"); - } } +/* + Note that because reductions are always between 2 values, the reduction + function takes 1 input and one temp value, where the temp contains the partial + result. So if we have 2 reductions (aka 2 outputs), we also need 2 inputs and + 2 temp. A call to function reductionBodyFnList[k] (namely the kth entry in the + list) will be instantiated with the kth input value and the kth temp, and its + result is ultimately saved into the kth output. + + This was not the case for simdIterateIE, where all of the inputs are provided + to each of the functions computing one output. Here we only pass a pair of + input & temp value to each function. + + This is reflected in the Body types below. + + Allows calls with no outputs and no post-processing functions. 
In such a case, + only perform the reductions into the tmps. (An illustrative usage sketch appears further below.) +*/ + +// Definition of SimdReductionBodyFn & SimdPostReductionBodyFn, see +// Mlir/DialectBuilder.hpp + template void simdReduceIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, mlir::ArrayRef tmps, mlir::ArrayRef tmpAFs, mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, mlir::ArrayRef initVals, - /* reduction function (simd or scalar) */ - mlir::function_ref inputVals, - mlir::ArrayRef tmpVals, - llvm::SmallVectorImpl &resultVals, int64_t VL)> - reductionBuilderFn, - /* post reduction function (simd to scalar + post processing)*/ - mlir::function_ref tmpVals, - llvm::SmallVectorImpl &scalarOutputs, int64_t VL)> - postProcessingBuilderFn) { + /* reduction functions (simd or scalar) */ + mlir::ArrayRef> reductionBodyFnList, + /* post reduction functions (simd to scalar + post processing)*/ + mlir::ArrayRef> postReductionBodyFnList) { MultiDialectBuilder create(builder); MEM_BUILDER createMem(builder); - int64_t inputSize = inputs.size(); - int64_t outputSize = outputs.size(); - assert((int64_t)inputAFs.size() == inputSize && "expect same input size"); - assert(tmps.size() == tmpAFs.size() && "expect same tmp size"); - assert((int64_t)outputAFs.size() == outputSize && "expect output same size"); - assert((int64_t)tmps.size() == outputSize && "expect 1 tmp per output"); - assert((int64_t)initVals.size() == outputSize && "expect 1 init per output"); + uint64_t inputSize = inputs.size(); + uint64_t tmpSize = tmps.size(); + uint64_t outputSize = outputs.size(); + // Test same number of values & AFs. + assert(inputAFs.size() == inputSize && "expect same input size"); + assert(tmpAFs.size() == tmpSize && "expect same tmps size"); + assert(outputAFs.size() == outputSize && "expect output same size"); + // Same number of init, reduction functions, tmps as inputs. + assert(reductionBodyFnList.size() == inputSize && "1 red fn per input"); + assert(tmpSize == inputSize && "expect 1 tmp per input"); + assert(initVals.size() == inputSize && "expect 1 init per input"); + // Same number of post reductions as outputs. + assert(postReductionBodyFnList.size() == outputSize && "1 post-reduction fn per output"); // Gather element and vector types and perform the inits. Do it in SIMD mode // regardless. llvm::SmallVector vectorTypes; - for (int64_t o = 0; o < outputSize; ++o) { - mlir::Value initVal = initVals[o]; + for (uint64_t i = 0; i < inputSize; ++i) { + mlir::Value initVal = initVals[i]; mlir::Type elementType = initVal.getType(); auto vectorType = mlir::VectorType::get({VL}, elementType); vectorTypes.emplace_back(vectorType); mlir::Value initVec = create.vec.splat(vectorType, initVal); - create.vec.storeIE(initVec, tmps[o], tmpAFs[o], {}); + create.vec.storeIE(initVec, tmps[i], tmpAFs[i]); } if (VL > 1) { // Logic: see simdIterateIE. @@ -548,46 +381,38 @@ void simdReduceIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, MultiDialectBuilder create(b); // Load inputs in SIMD mode, indexed by loopInd[0] in innermost dim.
llvm::SmallVector inputVals; - for (int64_t i = 0; i < inputSize; ++i) { + for (uint64_t i = 0; i < inputSize; ++i) { auto inputType = mlir::cast(inputs[i].getType()); auto vecType = mlir::VectorType::get({VL}, inputType.getElementType()); - mlir::Value inputVal = - create.vec.loadIE(vecType, inputs[i], inputAFs[i], {loopInd[0]}); - inputVals.emplace_back(inputVal); + inputVals.emplace_back( + create.vec.loadIE(vecType, inputs[i], inputAFs[i], {loopInd[0]})); } // Load tmp value in SIMD mode (no indexing, same value over & over). llvm::SmallVector tmpVals; - for (int64_t o = 0; o < outputSize; ++o) { - mlir::Value tmpVal = - create.vec.loadIE(vectorTypes[o], tmps[o], tmpAFs[o], {}); - tmpVals.emplace_back(tmpVal); + for (uint64_t i = 0; i < inputSize; ++i) { + tmpVals.emplace_back( + create.vec.loadIE(vectorTypes[i], tmps[i], tmpAFs[i])); } - // Call reduction. + // Call reduction, one per function each with their input and tmp value. llvm::SmallVector resultVals; - reductionBuilderFn(b, inputVals, tmpVals, resultVals, VL); - assert((int64_t)resultVals.size() == outputSize && - "expect ouputSize results"); + for (uint64_t i = 0; i < inputSize; ++i) { + resultVals.emplace_back( + reductionBodyFnList[i](b, inputVals[i], tmpVals[i], VL)); + } // Save tmp values in SIMD mode. - for (int64_t o = 0; o < outputSize; ++o) { - create.vec.storeIE(resultVals[o], tmps[o], tmpAFs[o], {}); + for (uint64_t i = 0; i < inputSize; ++i) { + create.vec.storeIE(resultVals[i], tmps[i], tmpAFs[i]); } }; // Want SIMD, execute full SIMD loops reductions blocked by VL. // Perform SIMD reduction: iterates over all SIMD vectors. - - if constexpr (std::is_same::value) { - // Implementation with Krnl. - mlir::ValueRange loopDef = builder.defineLoops(1); - mlir::ValueRange blockedLoopDef = builder.block(loopDef[0], VL); - builder.iterateIE( - loopDef, {blockedLoopDef[0]}, {lb}, {simdUb}, simdLoopBody); - } else if constexpr (std::is_same::value) { - // Implementation with SCF. - builder.forLoop(lb.getValue(), simdUb.getValue(), VL, simdLoopBody); - } else { + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) + builder.forLoopIE(lb, simdUb, VL, false /*parallel*/, simdLoopBody); + else llvm_unreachable("BUILDER type not supported"); - } if (fullySimd) { // No leftovers, no additional iterations to be done. @@ -622,59 +447,363 @@ void simdReduceIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, // We now perform sequential reduction in the tmps 1st element. Load // inputs in sequential mode indexed by loopInd[0] in innermost dim. llvm::SmallVector inputVals; - for (int64_t i = 0; i < inputSize; ++i) { - mlir::Value inputVal = - createMem.loadIE(inputs[i], inputAFs[i], {loopInd[0]}); - inputVals.emplace_back(inputVal); + for (uint64_t i = 0; i < inputSize; ++i) { + inputVals.emplace_back( + createMem.loadIE(inputs[i], inputAFs[i], {loopInd[0]})); } // Load tmps in scalar mode (no indexing, same value over & over). llvm::SmallVector tmpVals; - for (int64_t o = 0; o < outputSize; ++o) { - mlir::Value tmpVal = createMem.loadIE(tmps[o], tmpAFs[o], {}); - tmpVals.emplace_back(tmpVal); + for (uint64_t i = 0; i < inputSize; ++i) { + tmpVals.emplace_back(createMem.loadIE(tmps[i], tmpAFs[i])); } // Call reduction. 
llvm::SmallVector resultVals; - reductionBuilderFn(b, inputVals, tmpVals, resultVals, 1); - assert((int64_t)resultVals.size() == outputSize && - "expect ouputSize results"); + for (uint64_t i = 0; i < inputSize; ++i) { + resultVals.emplace_back( + reductionBodyFnList[i](b, inputVals[i], tmpVals[i], 1)); + } // Save tmp values in sequential mode. - for (int64_t o = 0; o < outputSize; ++o) { - createMem.storeIE(resultVals[o], tmps[o], tmpAFs[o], {}); + for (uint64_t i = 0; i < inputSize; ++i) { + createMem.storeIE(resultVals[i], tmps[i], tmpAFs[i]); } }; // Perform scalar loop. - if constexpr (std::is_same::value) { - // Implementation with Krnl. - mlir::ValueRange loopDef = builder.defineLoops(1); - builder.iterateIE(loopDef, loopDef, {lb}, {ub}, scalarLoopBody); - } else if constexpr (std::is_same::value) { - // Implementation with SCF. - builder.forLoop(lb.getValue(), ub.getValue(), 1, scalarLoopBody); - } else { + if constexpr (std::is_same::value || + std::is_same::value || + std::is_same::value) + builder.forLoopIE(lb, ub, 1, false /*parallel*/, scalarLoopBody); + else llvm_unreachable("BUILDER type not supported"); - } } + if (outputSize == 0) + return; // No outputs, we are done. + // Now perform post processing. Load all tmps. + assert(tmpSize == outputSize && "expect one tmp per output"); llvm::SmallVector tmpVals; - for (int64_t o = 0; o < outputSize; ++o) { + for (uint64_t o = 0; o < outputSize; ++o) { // Load tmp in vector mode. - mlir::Value tmpVal = - create.vec.loadIE(vectorTypes[o], tmps[o], tmpAFs[o], {}); - tmpVals.emplace_back(tmpVal); + tmpVals.emplace_back(create.vec.loadIE(vectorTypes[o], tmps[o], tmpAFs[o])); } - llvm::SmallVector scalarOutputs; + llvm::SmallVector finalResults; // Invoke the post processing operations, which takes each tmp vector and // reduces it to a scalar. - postProcessingBuilderFn(builder, tmpVals, scalarOutputs, VL); - assert((int64_t)scalarOutputs.size() == outputSize && - "expect outputSize results"); + for (uint64_t o = 0; o < outputSize; ++o) { + finalResults.emplace_back( + postReductionBodyFnList[o](builder, tmpVals[o], 1)); + } // Store the scalar reductions. - for (int64_t o = 0; o < outputSize; ++o) { - createMem.storeIE(scalarOutputs[o], outputs[o], outputAFs[o]); + for (uint64_t o = 0; o < outputSize; ++o) { + createMem.storeIE(finalResults[o], outputs[o], outputAFs[o]); + } +} + +template +void simdReduce2DIE(const BUILDER &builder, IndexExpr lb, IndexExpr ub, + int64_t VL, bool fullySimd, mlir::Value input, DimsExpr inputAF, + mlir::Value tmp, DimsExpr tmpAF, mlir::Value output, DimsExpr outputAF, + mlir::Value initVal, + /* reduction functions (simd or scalar) */ + SimdReductionBodyFn reductionBodyFn, + /* post reduction functions (simd to scalar + post processing)*/ + SimdPostReductionBodyFn postReductionBodyFn) { + // Expect 2D or more input and tmp. + auto inputType = mlir::cast(input.getType()); + auto tmpType = mlir::cast(tmp.getType()); + uint64_t inputRank = inputType.getRank(); + uint64_t tmpRank = tmpType.getRank(); + assert(inputRank == inputAF.size() && "expected same size"); + assert(tmpRank == tmpAF.size() && "expected same size"); + assert(inputRank >= 2 && "expected rank 2D+"); + assert(tmpRank >= 2 && "expected rank 2D+"); + mlir::Type elementType = inputType.getElementType(); + + // Perform a VL x VL reduction along the innermost 2 dimensions. + // Reuse the simdReduceIE functionality to do so. 
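/* Editorial aside, not part of the patch: a minimal usage sketch of the per-input function lists accepted by simdReduceIE, shown here for a single sum reduction on a generic BUILDER that provides simdReduceIE (e.g. KrnlBuilder). The names input, inputAF, tmp, tmpAF, out, outAF, and zeroVal are hypothetical, and the exact combining-kind spelling is assumed:

     builder.simdReduceIE(lb, ub, VL, fullySimd,
         {input}, {inputAF}, {tmp}, {tmpAF}, {out}, {outAF}, {zeroVal},
         {[&](const BUILDER &b, Value inputVal, Value tmpVal,
              int64_t VL) -> Value {
           MultiDialectBuilder<MathBuilder> create(b);
           // Running partial sum, in simd (VL > 1) or scalar mode.
           return create.math.add(tmpVal, inputVal);
         }},
         {[&](const BUILDER &b, Value tmpVal, int64_t VL) -> Value {
           MultiDialectBuilder<VectorBuilder> create(b);
           // Horizontally reduce the simd partial sums to one scalar.
           return create.vec.reduction(
               VectorBuilder::CombiningKind::ADD, tmpVal);
         }});
*/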
+ llvm::SmallVector newInputs(VL, input); + llvm::SmallVector newInputAFs(VL, inputAF); + llvm::SmallVector newTmps(VL, tmp); + llvm::SmallVector newTmpAFs(VL, tmpAF); + llvm::SmallVector newInitVals(VL, initVal); + llvm::SmallVector, 8> newReductionBodyFnList( + VL, reductionBodyFn); + + // Init the new data structures for VL reductions of VL values + uint64_t inputM2 = inputRank - 2; + uint64_t tmpM2 = tmpRank - 2; + for (int64_t v = 0; v < VL; ++v) { + // Each inputs/tmp is offset by 1 in the second to last dim; + newInputAFs[v][inputM2] = newInputAFs[v][inputM2] + v; + newTmpAFs[v][tmpM2] = newTmpAFs[v][tmpM2] + v; + } + // Step 1: perform the reduction of VL vectors into VL tmps. No output & post + // reduction as we will do it here. + builder.simdReduceIE(lb, ub, VL, fullySimd, newInputs, newInputAFs, newTmps, + newTmpAFs, {}, {}, newInitVals, newReductionBodyFnList, {}); + + // Step 2, perform reduction of VL vectors of VL values into 1 vector of VL. + // Load all temp vectors. + llvm::SmallVector redIn, redOut; + MultiDialectBuilder create(builder); + mlir::VectorType vecType = mlir::VectorType::get({VL}, elementType); + for (int64_t v = 0; v < VL; ++v) { + redIn.emplace_back(create.vec.loadIE(vecType, newTmps[v], newTmpAFs[v])); + } + // Reduce all of the temp vectors at once. + auto redFct = [&](mlir::Value a, mlir::Value b) -> mlir::Value { + return reductionBodyFn(builder, a, b, VL); + }; + create.vec.multiReduction(redIn, redFct, redOut); + // The redOut list should have one value with SIMD of VL. + assert(redOut.size() == 1 && "expected only one val"); + mlir::Value accumulatedVal = redOut[0]; + // Perform post processing (e.g. division by number of elements). + accumulatedVal = postReductionBodyFn(builder, accumulatedVal, VL); + // Store final values. + create.vec.storeIE(accumulatedVal, output, outputAF); +} + +} // namespace impl + +//===----------------------------------------------------------------------===// +// Templates for GenericAffineBuilder +//===----------------------------------------------------------------------===// + +template +mlir::Value GenericAffineBuilder::load(mlir::Value memref, + mlir::ValueRange indices, mlir::ValueRange offsets) const { + return onnx_mlir::impl::load( + *this, memref, indices, offsets); +} + +template +mlir::Value GenericAffineBuilder::loadIE(mlir::Value memref, + mlir::ArrayRef indices, mlir::ValueRange offsets) const { + return onnx_mlir::impl::loadIE( + *this, memref, indices, offsets); +} + +template +inline void GenericAffineBuilder::store(mlir::Value val, + mlir::Value memref, mlir::ValueRange indices, + mlir::ValueRange offsets) const { + onnx_mlir::impl::store( + *this, val, memref, indices, offsets); +} + +template +inline void GenericAffineBuilder::storeIE(mlir::Value val, + mlir::Value memref, mlir::ArrayRef indices, + mlir::ValueRange offsets) const { + onnx_mlir::impl::storeIE( + *this, val, memref, indices, offsets); +} + +template +inline mlir::Operation *GenericAffineBuilder::prefetch( + mlir::Value memref, mlir::AffineMap map, mlir::ValueRange indices, + bool isWrite, unsigned localityHint, bool isDataCache) { + llvm::SmallVector indexArray(indices); + return b().template create( + loc(), memref, map, indexArray, isWrite, localityHint, isDataCache); +} + +template +inline void GenericAffineBuilder::forLoopIE(IndexExpr lb, + IndexExpr ub, int64_t step, bool useParallel, + GenericAffineLoopBodyFn builderFn) const { + // Transform IndexExpressions into value maps and list of + // operands. 
+ mlir::AffineMap lbMap, ubMap; + llvm::SmallVector lbOperands, ubOperands; + lb.getAffineMapAndOperands(lbMap, lbOperands); + ub.getAffineMapAndOperands(ubMap, ubOperands); + + if (useParallel) { + // Create affine parallel for. + llvm::SmallVector types; + llvm::SmallVector reds; + llvm::SmallVector lbs, ubs; + llvm::SmallVector steps; + lbs.emplace_back(lbMap); + ubs.emplace_back(ubMap); + steps.emplace_back(step); + auto parallelLoop = b().template create( + loc(), types, reds, lbs, lbOperands, ubs, ubOperands, steps); + mlir::Block *bodyBlock = parallelLoop.getBody(); + // From extractInductionVars in AffineOps.cpp. + assert(bodyBlock->getNumArguments() == 1 && "expected one loop index"); + mlir::Value index = bodyBlock->getArgument(0); + // Code inspired from AffineForOp::build in AffineOps.cpp. + mlir::OpBuilder::InsertionGuard guard(b()); + b().setInsertionPointToStart(bodyBlock); + GenericAffineBuilder createAffine(b(), loc()); + builderFn(createAffine, {index}); + createAffine.yield(); + } else { + // Create affine for. + b().template create(loc(), lbOperands, lbMap, + ubOperands, ubMap, step, mlir::ValueRange{}, + [&](mlir::OpBuilder &b, mlir::Location loc, mlir::Value index, + mlir::ValueRange args) { + GenericAffineBuilder createAffine(b, loc); + builderFn(createAffine, {index}); + createAffine.yield(); + }); + } +} + +// Sequential only version. +template +inline void GenericAffineBuilder::forLoopIE(IndexExpr lb, + IndexExpr ub, int64_t step, GenericAffineLoopBodyFn builderFn) const { + forLoopIE(lb, ub, step, false /*use parallel*/, builderFn); +} + +template +inline void GenericAffineBuilder::forLoopsIE( + mlir::ArrayRef lbs, mlir::ArrayRef ubs, + mlir::ArrayRef steps, GenericAffineLoopBodyFn builderFn) const { + assert(lbs.size() == ubs.size() && "expected identical sizes"); + assert(lbs.size() == steps.size() && "expected identical sizes"); + llvm::SmallVector loopIndices; + recursionForLoopsIE(lbs, ubs, steps, loopIndices, builderFn); +} + +template +inline void GenericAffineBuilder::simdIterateIE(IndexExpr lb, + IndexExpr ub, int64_t VL, bool fullySimd, bool useParallel, + mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, + mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, + mlir::ArrayRef bodyFnList) const { + onnx_mlir::impl::simdIterateIE, + MemRefBuilder>(*this, lb, ub, VL, fullySimd, useParallel, inputs, + inputAFs, outputs, outputAFs, bodyFnList); +} + +template +inline void GenericAffineBuilder::simdReduceIE(IndexExpr lb, + IndexExpr ub, int64_t VL, bool fullySimd, + mlir::ArrayRef inputs, mlir::ArrayRef inputAFs, + mlir::ArrayRef tmps, mlir::ArrayRef tmpAFs, + mlir::ArrayRef outputs, mlir::ArrayRef outputAFs, + mlir::ArrayRef initVals, + /* reduction function (simd or scalar) */ + mlir::ArrayRef reductionFnList, + /* post reduction function (simd to scalar + post processing)*/ + mlir::ArrayRef postReductionFnList) + const { + onnx_mlir::impl::simdReduceIE, + MemRefBuilder>(*this, lb, ub, VL, fullySimd, inputs, inputAFs, tmps, + tmpAFs, outputs, outputAFs, initVals, reductionFnList, + postReductionFnList); +} + +template +inline void GenericAffineBuilder::simdReduce2DIE( + IndexExpr lb, IndexExpr ub, int64_t VL, bool fullySimd, mlir::Value input, + DimsExpr inputAF, mlir::Value tmp, DimsExpr tmpAF, mlir::Value output, + DimsExpr outputAF, mlir::Value initVal, + /* reduction functions (simd or scalar) */ + GenericAffineSimdReductionBodyFn reductionBodyFn, + /* post reduction functions (post processing ONLY)*/ + GenericAffineSimdPostReductionBodyFn 
postReductionBodyFn) const { + onnx_mlir::impl::simdReduce2DIE, + MemRefBuilder>(*this, lb, ub, VL, fullySimd, input, inputAF, tmp, tmpAF, + output, outputAF, initVal, reductionBodyFn, postReductionBodyFn); +} + +// This if then else construct has no arguments to the blocks. +template +inline void GenericAffineBuilder::ifThenElseIE( + IndexExprScope &scope, mlir::ArrayRef conditions, + mlir::function_ref thenFn, + mlir::function_ref elseFn) const { + int64_t rank = conditions.size(); + llvm::SmallVector affineCond; + bool allTrue = true; + bool allFalse = true; + for (IndexExpr c : conditions) { + assert(c.isAffine() && "conditions expected to be affine"); + affineCond.emplace_back(c.getAffineExpr()); + if (c.isLiteral()) { + if (c.getLiteral() < 0) // Inequality is expr >= 0, test if false. + allTrue = false; + if (c.getLiteral() >= 0) // Inequality is expr >= 0, test if true. + allFalse = false; + } else { + allTrue = allFalse = false; + } + } + llvm::SmallVector isEq(rank, false); + auto inset = mlir::IntegerSet::get( + scope.getNumDims(), scope.getNumSymbols(), affineCond, isEq); + llvm::SmallVector dimAndSymbolList; + scope.getDimAndSymbolList(dimAndSymbolList); + auto ifOp = b().template create( + loc(), inset, dimAndSymbolList, true); + mlir::Block *thenBlock = ifOp.getThenBlock(); + mlir::Block *elseBlock = ifOp.getElseBlock(); + if (!allFalse) { + appendToBlock(thenBlock, [&](mlir::ValueRange args) { + GenericAffineBuilder createAffine(b(), loc()); + thenFn(createAffine); + }); + } + if (!allTrue) { + appendToBlock(elseBlock, [&](mlir::ValueRange args) { + GenericAffineBuilder createAffine(b(), loc()); + elseFn(createAffine); + }); } } -} // namespace impl \ No newline at end of file +template +mlir::Value GenericAffineBuilder::apply( + mlir::AffineMap map, mlir::ValueRange operands) const { + return b().template create(loc(), map, operands); +} + +template +inline void GenericAffineBuilder::yield() const { + b().template create(loc()); +} + +// Support for multiple forLoopIE loops. +template +void GenericAffineBuilder::recursionForLoopsIE( + mlir::ArrayRef lbs, mlir::ArrayRef ubs, + mlir::ArrayRef steps, + llvm::SmallVectorImpl &loopIndices, + mlir::function_ref + builderFn) const { + int d = loopIndices.size(); + if (d < (int)lbs.size()) { + // Issue a loop and recurse again. + forLoopIE(lbs[d], ubs[d], steps[d], + [&](GenericAffineBuilder &createAffine, mlir::ValueRange loopInd) { + loopIndices.emplace_back(loopInd[0]); + recursionForLoopsIE(lbs, ubs, steps, loopIndices, builderFn); + }); + } else { + // Call lambda function + GenericAffineBuilder createAffine(b(), loc()); + builderFn(createAffine, loopIndices); + } +} + +// Support for adding blocks. 
+template +inline void GenericAffineBuilder::appendToBlock( + mlir::Block *block, + mlir::function_ref builderFn) const { + mlir::OpBuilder::InsertionGuard guard(b()); + if (block->empty() || + !block->back().mightHaveTrait()) { + b().setInsertionPointToEnd(block); + } else + b().setInsertionPoint(&block->back()); + builderFn(block->getArguments()); +} diff --git a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir index 5233b7488d..ec98797e6b 100644 --- a/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Math/Reduction_with_canonicalize_O3.mlir @@ -343,53 +343,51 @@ func.func private @gpt2_original(%arg0 : tensor) -> tensor // CHECK: } // CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOAD_RES_2_MEM_1_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) +// CHECK-DAG: [[VAR_12_1_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = [[CST_0_]] to [[CST_768_]]){ -// CHECK: [[VAR_26_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_26_]]{{.}} : memref, vector<4xf32> +// CHECK: affine.for [[I_4_:%.+]] = 0 to 768 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_2_MEM_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_12_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_2_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_29_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_30_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_3_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_33_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_33_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, 
vector<4xf32> -// CHECK: [[VAR_34_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_34_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_4_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_2_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_38_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_38_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_5_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_2_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> -// CHECK: vector.store [[VAR_41_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[LOAD_RES_2_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_2_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_2_MEM_6_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_7_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_8_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_9_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_PARAM_0_MEM_5_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[LOAD_RES_2_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_17_1_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_1_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_5_]], [[LOAD_RES_2_MEM_10_]] : vector<4xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_8_]], [[LOAD_RES_2_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = vector.shuffle 
[[LOAD_RES_2_MEM_8_]], [[LOAD_RES_2_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = arith.addf [[VAR_18_]], [[VAR_19_]] : vector<4xf32> -// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_22_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_18_]], [[VAR_17_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_8_]], [[LOAD_RES_2_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_8_]], [[LOAD_RES_2_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_20_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_24_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[VAR_21_]], [[VAR_22_]] : vector<4xf32> -// CHECK-DAG: [[VAR_24_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = arith.divf [[VAR_23_]], [[VAR_24_]] : vector<4xf32> -// CHECK: vector.store [[VAR_25_]], [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_24_]], [[VAR_23_]] : vector<4xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> +// CHECK: [[VAR_27_:%.+]] = arith.divf [[VAR_25_]], [[VAR_26_]] : vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> // CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref @@ -456,53 +454,51 @@ func.func private @gpt2_no_keepdims(%arg0 : tensor) -> tensor<*xf32 // CHECK: krnl.store [[VAR_13_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[I_2_]]{{.}} : memref // CHECK: } // CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) +// CHECK-DAG: [[VAR_12_1_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = [[CST_0_]] to [[CST_768_]]){ -// CHECK: [[VAR_26_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_26_]]{{.}} : memref, vector<4xf32> +// CHECK: affine.for [[I_4_:%.+]] = 0 to 768 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load 
[[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_1_MEM_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_12_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = arith.addf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_29_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_30_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_33_:%.+]] = arith.addf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_33_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_34_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_34_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_1_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_38_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_7_]]#0, [[VAR_38_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_5_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_1_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> -// CHECK: vector.store [[VAR_41_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.addf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[LOAD_RES_1_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_1_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_1_MEM_6_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_7_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : 
memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_8_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_9_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_PARAM_0_MEM_5_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[LOAD_RES_1_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_17_1_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_1_:%.+]] = arith.addf [[LOAD_PARAM_0_MEM_5_]], [[LOAD_RES_1_MEM_10_]] : vector<4xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = arith.addf [[VAR_18_]], [[VAR_19_]] : vector<4xf32> -// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_22_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_18_]], [[VAR_17_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_20_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_24_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[VAR_21_]], [[VAR_22_]] : vector<4xf32> -// CHECK-DAG: [[VAR_24_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = arith.divf [[VAR_23_]], [[VAR_24_]] : vector<4xf32> -// CHECK: vector.store [[VAR_25_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_24_]], [[VAR_23_]] : vector<4xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> +// CHECK: [[VAR_27_:%.+]] = arith.divf [[VAR_25_]], [[VAR_26_]] : vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[RES_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> // CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref @@ -580,53 +576,51 @@ func.func private @gpt2_reduce2(%arg0 : tensor) -> tensor<*xf32> { // CHECK: krnl.store [[VAR_13_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[I_2_]]{{.}} : memref // CHECK: } // CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) +// CHECK-DAG: [[VAR_12_1_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) 
// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = [[CST_0_]] to [[CST_768_]]){ -// CHECK: [[VAR_26_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_26_]]{{.}} : memref, vector<4xf32> +// CHECK: affine.for [[I_4_:%.+]] = 0 to 768 step 4 { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_3_MEM_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_12_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_29_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_30_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_33_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_33_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_34_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_34_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_38_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_38_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], 
[[LOAD_VAR_reshape_MEM_4_]] : vector<4xf32> -// CHECK: vector.store [[VAR_41_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_4_]] : vector<4xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_17_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_5_]], [[LOAD_RES_3_MEM_10_]] : vector<4xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = arith.addf [[VAR_18_]], [[VAR_19_]] : vector<4xf32> -// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_22_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_18_]], [[VAR_17_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_20_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [0, 1, 4, 5] : vector<4xf32>, 
vector<4xf32> +// CHECK-DAG: [[VAR_24_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[VAR_21_]], [[VAR_22_]] : vector<4xf32> -// CHECK-DAG: [[VAR_24_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = arith.divf [[VAR_23_]], [[VAR_24_]] : vector<4xf32> -// CHECK: vector.store [[VAR_25_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_24_]], [[VAR_23_]] : vector<4xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> +// CHECK: [[VAR_27_:%.+]] = arith.divf [[VAR_25_]], [[VAR_26_]] : vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> // CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref @@ -650,7 +644,9 @@ func.func private @gpt2_one_not_multiple(%arg0 : tensor) -> tensor // CHECK-LABEL: func.func private @gpt2_one_not_multiple // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: [[CST_773_:%.+]] = arith.constant 773 : index // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[CST_776_:%.+]] = arith.constant 776 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index @@ -691,7 +687,7 @@ func.func private @gpt2_one_not_multiple(%arg0 : tensor) -> tensor // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 776){ +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 773){ // CHECK: [[VAR_14_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index // CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[I_2_]], [[VAR_14_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> @@ -704,53 +700,51 @@ func.func private @gpt2_one_not_multiple(%arg0 : tensor) -> tensor // CHECK: krnl.store [[VAR_13_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[I_2_]]{{.}} : memref // CHECK: } // CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) +// CHECK-DAG: [[VAR_12_1_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, 
!krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = [[CST_0_]] to [[CST_776_]]){ -// CHECK: [[VAR_26_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[VAR_26_]]{{.}} : memref, vector<4xf32> +// CHECK: scf.for [[I_4_:%.+]] = [[CST_0_]] to [[CST_773_]] step [[CST_4_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_3_MEM_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_12_1_]], [[I_4_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_29_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_30_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_33_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_33_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_34_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_34_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_37_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_38_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_38_]], [[VAR_26_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_41_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_4_]] : vector<4xf32> -// CHECK: vector.store [[VAR_41_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_36_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_39_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], 
[[LOAD_VAR_reshape_MEM_4_]] : vector<4xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_38_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_39_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_17_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_6_]], [[LOAD_RES_3_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_17_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_5_]], [[LOAD_RES_3_MEM_10_]] : vector<4xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = arith.addf [[VAR_18_]], [[VAR_19_]] : vector<4xf32> -// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_22_:%.+]] = vector.shuffle [[VAR_17_1_]], [[VAR_20_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_19_:%.+]] = arith.addf [[VAR_18_]], [[VAR_17_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_20_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_24_:%.+]] = vector.shuffle [[VAR_19_]], [[VAR_22_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[VAR_21_]], [[VAR_22_]] : vector<4xf32> -// CHECK-DAG: [[VAR_24_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = arith.divf [[VAR_23_]], [[VAR_24_]] : vector<4xf32> -// CHECK: vector.store [[VAR_25_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_24_]], [[VAR_23_]] : 
vector<4xf32> +// CHECK-DAG: [[VAR_26_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> +// CHECK: [[VAR_27_:%.+]] = arith.divf [[VAR_25_]], [[VAR_26_]] : vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> // CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref @@ -767,12 +761,19 @@ func.func private @gpt2_no_simd_as_not_mult_of_VL(%arg0 : tensor) // mlir2FileCheck.py // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0, s1] -> (s1)> -// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<()[s0, s1, s2] -> (s2)> -// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<()[s0, s1, s2, s3] -> (s3)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0)[s0] -> (-d0 + s0 - 4)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (d0 + 1)> +// CHECK-DAG: [[MAP_3_:#.+]] = affine_map<(d0) -> (d0 + 2)> +// CHECK-DAG: [[MAP_4_:#.+]] = affine_map<(d0) -> (d0 + 3)> // CHECK-LABEL: func.func private @gpt2_no_simd_as_not_mult_of_VL // CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> memref { +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: [[CST_872_:%.+]] = arith.constant 872 : index +// CHECK-DAG: [[CST_870_:%.+]] = arith.constant 870 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[CST_873_:%.+]] = arith.constant 873 : index -// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-NOT: separator of consecutive DAGs @@ -787,24 +788,114 @@ func.func private @gpt2_no_simd_as_not_mult_of_VL(%arg0 : tensor) // CHECK-DAG: [[VAR_2_:%.+]] = arith.muli [[VAR_dim_]], [[VAR_dim_]]_0 : index // CHECK: [[VAR_3_:%.+]] = arith.floordivsi [[VAR_1_]], [[VAR_2_]] : index // CHECK: [[VAR_4_:%.+]] = arith.index_cast [[VAR_3_]] : index to i64 -// CHECK: [[VAR_5_:%.+]] = arith.sitofp [[VAR_4_]] : i64 to f32 -// CHECK: krnl.memset [[RES_]], [[CST_0_dot_000000_]] : memref -// CHECK-DAG: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.sitofp [[VAR_4_]] : i64 to f32 // CHECK-DAG: [[VAR_dim_3_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_]] : memref // CHECK-DAG: [[VAR_dim_4_:%.+]] = memref.dim [[PARAM_0_]], [[CST_1_]] : memref -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_3_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_3_]], [[VAR_dim_4_]]{{.}}, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 97, [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 9){ -// CHECK: [[VAR_8_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[VAR_8_]]#2, [[VAR_8_]]#3] : memref -// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[CST_0_]], [[CST_0_]]{{.}} : memref -// CHECK: [[VAR_11_:%.+]] = arith.addf [[LOAD_RES_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 -// CHECK: krnl.store [[VAR_11_]], [[RES_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[CST_0_]], [[CST_0_]]{{.}} : memref -// CHECK: } -// CHECK: [[LOOP_1_:%.+]]:4 = krnl.define_loops 4 -// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) with 
([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to [[MAP_1_]](){{.}}[[VAR_dim_3_]], [[VAR_dim_4_]], [[VAR_dim_]]{{.}}, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to [[MAP_2_]](){{.}}[[VAR_dim_3_]], [[VAR_dim_4_]], [[VAR_dim_]], [[VAR_dim_]]_0], [[LOOP_1_]]#2 -> [[I_6_:%.+]] = 0 to 1, [[LOOP_1_]]#3 -> [[I_7_:%.+]] = 0 to 1){ -// CHECK: [[VAR_8_1_:%.+]]:4 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) : (!krnl.loop, !krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index, index) -// CHECK: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[RES_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1, [[VAR_8_1_]]#2, [[VAR_8_1_]]#3] : memref -// CHECK: [[LOAD_RES_MEM_2_:%.+]] = arith.divf [[LOAD_RES_MEM_1_]], [[VAR_5_]] : f32 -// CHECK: krnl.store [[LOAD_RES_MEM_2_]], [[RES_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1, [[VAR_8_1_]]#2, [[VAR_8_1_]]#3] : memref +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[VAR_dim_3_]], [[RES_1_]][0] : memref<3xindex> +// CHECK: affine.store [[VAR_dim_4_]], [[RES_1_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_873_]], [[RES_1_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref, memref<3xindex>) -> memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<2xindex> +// CHECK: affine.store [[VAR_dim_]], [[RES_2_]][0] : memref<2xindex> +// CHECK: affine.store [[VAR_dim_0_]], [[RES_2_]][1] : memref<2xindex> +// CHECK-DAG: [[VAR_reshape_7_:%.+]] = memref.reshape [[RES_]]([[RES_]]_6) : (memref, memref<2xindex>) -> memref +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to [[MAP_0_]](){{.}}[[VAR_dim_]], [[VAR_dim_]]_0]){ +// CHECK-DAG: [[VAR_7_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK: [[VAR_8_:%.+]] = affine.apply [[MAP_1_]]([[VAR_7_]]#1){{.}}[[VAR_dim_0_]]{{.}} +// CHECK: [[VAR_9_:%.+]] = arith.cmpi slt, [[VAR_8_]], [[CST_0_]] : index +// CHECK: scf.if [[VAR_9_]] { +// CHECK: scf.for [[I_2_:%.+]] = [[VAR_7_]]#1 to [[VAR_dim_0_]] step [[CST_1_]] { +// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 870){ +// CHECK: [[VAR_15_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[I_2_]], [[VAR_15_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_18_:%.+]] = arith.addf [[LOAD_RES_3_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<4xf32> +// CHECK: vector.store [[VAR_18_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> 
[[I_4_:%.+]] = 872 to 873){ +// CHECK: [[VAR_15_1_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[I_2_]], [[VAR_15_1_]]{{.}} : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: [[VAR_18_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_18_1_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: } +// CHECK: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: [[VAR_13_:%.+]] = vector.reduction , [[LOAD_RES_3_MEM_2_]] : vector<4xf32> into f32 +// CHECK: [[VAR_14_:%.+]] = arith.divf [[VAR_13_]], [[VAR_5_]] : f32 +// CHECK: krnl.store [[VAR_14_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[I_2_]]{{.}} : memref +// CHECK: } +// CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOOP_2_:%.+]] = affine.apply [[MAP_3_]]([[VAR_7_]]#1) +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = affine.apply [[MAP_4_]]([[VAR_7_]]#1) +// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: scf.for [[I_5_:%.+]] = [[CST_0_]] to [[CST_870_]] step [[CST_4_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[I_5_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[I_5_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_2_]], [[I_5_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_3_MEM_2_]], [[I_5_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_48_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_50_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_4_]] : vector<4xf32> +// CHECK-DAG: [[VAR_51_:%.+]] = arith.addf [[LOAD_RES_3_MEM_6_]], [[LOAD_VAR_reshape_MEM_5_]] : vector<4xf32> +// CHECK: vector.store [[VAR_48_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_49_]], [[RES_3_]]{{.}}[[CST_1_]], 
[[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_50_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_51_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_6_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1, [[CST_872_]]{{.}} : memref +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_7_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_1_]], [[CST_872_]]{{.}} : memref +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_8_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOOP_2_]], [[CST_872_]]{{.}} : memref +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_9_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_7_]]#0, [[LOAD_RES_3_MEM_2_]], [[CST_872_]]{{.}} : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_21_:%.+]] = arith.addf [[LOAD_RES_3_MEM_7_]], [[LOAD_VAR_reshape_MEM_6_]] : f32 +// CHECK-DAG: [[VAR_22_:%.+]] = arith.addf [[LOAD_RES_3_MEM_8_]], [[LOAD_VAR_reshape_MEM_7_]] : f32 +// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[LOAD_RES_3_MEM_9_]], [[LOAD_VAR_reshape_MEM_8_]] : f32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.addf [[LOAD_RES_3_MEM_10_]], [[LOAD_VAR_reshape_MEM_9_]] : f32 +// CHECK: memref.store [[VAR_21_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_22_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_23_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_24_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_11_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_12_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_13_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_14_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_29_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_11_]], [[LOAD_RES_3_MEM_12_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_11_]], [[LOAD_RES_3_MEM_12_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_31_:%.+]] = arith.addf [[VAR_30_]], [[VAR_29_]] : vector<4xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_13_]], [[LOAD_RES_3_MEM_14_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_33_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_13_]], [[LOAD_RES_3_MEM_14_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_34_:%.+]] = arith.addf [[VAR_33_]], [[VAR_32_]] : vector<4xf32> +// CHECK-DAG: [[VAR_35_:%.+]] = vector.shuffle [[VAR_31_]], 
[[VAR_34_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_36_:%.+]] = vector.shuffle [[VAR_31_]], [[VAR_34_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_37_:%.+]] = arith.addf [[VAR_36_]], [[VAR_35_]] : vector<4xf32> +// CHECK-DAG: [[VAR_38_:%.+]] = vector.splat [[VAR_5_]] : vector<4xf32> +// CHECK: [[VAR_39_:%.+]] = arith.divf [[VAR_37_]], [[VAR_38_]] : vector<4xf32> +// CHECK: vector.store [[VAR_39_]], [[VAR_reshape_7_]]{{.}}[[VAR_7_]]#0, [[VAR_7_]]#1] : memref, vector<4xf32> +// CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -826,59 +917,57 @@ func.func private @test_reducemax_v13_bis(%arg0 : tensor<1028x256xf32>) -> tenso // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0xFF800000> : vector<4xf32> // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1028xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) // CHECK: krnl.iterate([[BLOCK_TILE__0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 1028){ // CHECK-DAG: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__0_]]) : (!krnl.loop) -> index // CHECK-DAG: [[RES_1_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]) +// CHECK-DAG: [[VAR_3_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]) +// CHECK-DAG: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]) // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_1_:%.+]] = [[CST_0_]] to [[CST_256_]]){ -// CHECK: [[VAR_16_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]], [[VAR_1_]]6] : memref<1028x256xf32>, vector<4xf32> +// CHECK: affine.for [[I_1_:%.+]] = 0 to 256 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]], [[I_1_]]{{.}} : memref<1028x256xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_2_]], [[I_1_]]{{.}} : memref<1028x256xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_3_]], [[I_1_]]{{.}} : memref<1028x256xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_4_]], [[I_1_]]{{.}} : memref<1028x256xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_:%.+]] = 
vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_19_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> -// CHECK: vector.store [[VAR_19_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_20_]], [[VAR_16_]]{{.}} : memref<1028x256xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_23_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_23_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_24_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_24_]], [[VAR_16_]]{{.}} : memref<1028x256xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_27_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_27_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_28_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_28_]], [[VAR_16_]]{{.}} : memref<1028x256xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_31_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_31_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_26_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> +// CHECK-DAG: [[VAR_27_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_26_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_28_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_29_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_5_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_6_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_7_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_4_]], 
[[LOAD_RES_1_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_4_]], [[LOAD_RES_1_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_4_]], [[LOAD_RES_1_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_4_]], [[LOAD_RES_1_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]] = arith.maxnumf [[VAR_7_]], [[VAR_8_]] : vector<4xf32> -// CHECK-DAG: [[VAR_10_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_12_:%.+]] = arith.maxnumf [[VAR_10_]], [[VAR_11_]] : vector<4xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[VAR_9_]], [[VAR_12_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_14_:%.+]] = vector.shuffle [[VAR_9_]], [[VAR_12_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_15_:%.+]] = arith.maxnumf [[VAR_13_]], [[VAR_14_]] : vector<4xf32> -// CHECK: vector.store [[VAR_15_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<1028xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.maxnumf [[VAR_10_]], [[VAR_9_]] : vector<4xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_13_]], [[VAR_12_]] : vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_16_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_17_:%.+]] = arith.maxnumf [[VAR_16_]], [[VAR_15_]] : vector<4xf32> +// CHECK: vector.store [[VAR_17_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<1028xf32>, vector<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref<1028xf32> // CHECK: } @@ -904,7 +993,6 @@ func.func private @test_reducemax_v13_small(%arg0 : tensor<7x8xf32>) -> tensor<* // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index // CHECK-DAG: [[CST_7_:%.+]] = arith.constant 7 : index -// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<7xf32> // CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 // CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) @@ -930,50 +1018,48 @@ func.func private @test_reducemax_v13_small(%arg0 : tensor<7x8xf32>) -> tensor<* // CHECK: krnl.store [[VAR_6_]], [[RES_]]{{.}}[[I_1_]]{{.}} : memref<7xf32> // CHECK: } // CHECK: } else { +// CHECK-DAG: [[LOOP_1_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]) +// CHECK-DAG: [[LOAD_RES_1_MEM_1_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]) +// CHECK-DAG: [[VAR_6_1_:%.+]] = affine.apply [[MAP_3_]]([[VAR_1_]]) // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: 
vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_3_:%.+]] = [[CST_0_]] to [[CST_8_]]){ -// CHECK: [[VAR_18_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]], [[VAR_1_]]8] : memref<7x8xf32>, vector<4xf32> +// CHECK: affine.for [[I_3_:%.+]] = 0 to 8 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]], [[I_3_]]{{.}} : memref<7x8xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[LOOP_1_]], [[I_3_]]{{.}} : memref<7x8xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[LOAD_RES_1_MEM_1_]], [[I_3_]]{{.}} : memref<7x8xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_1_]], [[I_3_]]{{.}} : memref<7x8xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_2_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_21_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_21_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_22_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_22_]], [[VAR_18_]]{{.}} : memref<7x8xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_3_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_25_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_26_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_26_]], [[VAR_18_]]{{.}} : memref<7x8xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_4_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_29_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = affine.apply [[MAP_3_]]([[VAR_1_]]) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_4_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_30_]], [[VAR_18_]]{{.}} : memref<7x8xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_5_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_33_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> -// CHECK: vector.store [[VAR_33_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_2_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = arith.maxnumf 
[[LOAD_RES_1_MEM_3_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_4_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_5_]], [[LOAD_PARAM_0_MEM_4_]] : vector<4xf32> +// CHECK: vector.store [[VAR_28_]], [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_29_]], [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_30_]], [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_31_]], [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_1_MEM_6_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_7_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_8_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_1_MEM_9_:%.+]] = vector.load [[RES_1_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[LOAD_RES_1_MEM_10_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_10_1_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_6_]], [[LOAD_RES_1_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_11_:%.+]] = arith.maxnumf [[LOAD_RES_1_MEM_10_]], [[VAR_10_1_]] : vector<4xf32> -// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_14_:%.+]] = arith.maxnumf [[VAR_12_]], [[VAR_13_]] : vector<4xf32> -// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_16_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_17_:%.+]] = arith.maxnumf [[VAR_15_]], [[VAR_16_]] : vector<4xf32> -// CHECK: vector.store [[VAR_17_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<7xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = arith.maxnumf [[VAR_12_]], [[VAR_11_]] : vector<4xf32> +// CHECK-DAG: [[VAR_14_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[LOAD_RES_1_MEM_8_]], [[LOAD_RES_1_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_16_:%.+]] = arith.maxnumf [[VAR_15_]], [[VAR_14_]] : vector<4xf32> +// CHECK-DAG: [[VAR_17_:%.+]] = vector.shuffle [[VAR_13_]], [[VAR_16_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[VAR_13_]], [[VAR_16_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_19_:%.+]] = arith.maxnumf [[VAR_18_]], [[VAR_17_]] : vector<4xf32> 
+// CHECK: vector.store [[VAR_19_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<7xf32>, vector<4xf32> // CHECK: } // CHECK: } // CHECK: return [[RES_]] : memref<7xf32> @@ -1035,7 +1121,6 @@ func.func private @bertsquad10_same_pattern(%arg0 : tensor) -> te // CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> // CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_768_:%.+]] = arith.constant 768 : index // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index @@ -1057,53 +1142,52 @@ func.func private @bertsquad10_same_pattern(%arg0 : tensor) -> te // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){ // CHECK-DAG: [[VAR_6_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) // CHECK-DAG: [[RES_2_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = affine.apply [[MAP_2_]]([[VAR_6_]]#1) +// CHECK-DAG: [[VAR_8_:%.+]] = affine.apply [[MAP_3_]]([[VAR_6_]]#1) +// CHECK-DAG: [[VAR_9_:%.+]] = affine.apply [[MAP_4_]]([[VAR_6_]]#1) // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = [[CST_0_]] to [[CST_768_]]){ -// CHECK: [[VAR_23_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_6_]]#1, [[VAR_23_]]{{.}} : memref, vector<4xf32> +// CHECK: affine.for [[I_2_:%.+]] = 0 to 768 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_6_]]#1, [[I_2_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_7_]], [[I_2_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_8_]], [[I_2_]]{{.}} : memref, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_9_]], [[I_2_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_26_:%.+]] = arith.addf [[LOAD_RES_2_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> -// CHECK: vector.store [[VAR_26_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_27_:%.+]] = affine.apply [[MAP_2_]]([[VAR_6_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_27_]], [[VAR_23_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: 
[[LOAD_RES_2_MEM_1_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_30_:%.+]] = arith.addf [[LOAD_RES_2_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_30_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_31_:%.+]] = affine.apply [[MAP_3_]]([[VAR_6_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_31_]], [[VAR_23_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_2_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_34_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_34_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_35_:%.+]] = affine.apply [[MAP_4_]]([[VAR_6_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_6_]]#0, [[VAR_35_]], [[VAR_23_]]{{.}} : memref, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_3_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_38_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_38_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_33_:%.+]] = arith.addf [[LOAD_RES_2_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> +// CHECK-DAG: [[VAR_34_:%.+]] = arith.addf [[LOAD_RES_2_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_35_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_36_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_33_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_34_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_35_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_36_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_2_MEM_4_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_5_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_6_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_7_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_14_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: 
[[VAR_14_:%.+]] = arith.addf [[VAR_12_]], [[VAR_13_]] : vector<4xf32> -// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_16_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_17_:%.+]] = arith.addf [[VAR_15_]], [[VAR_16_]] : vector<4xf32> -// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[VAR_14_]], [[VAR_17_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_19_:%.+]] = vector.shuffle [[VAR_14_]], [[VAR_17_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_16_:%.+]] = arith.addf [[VAR_15_]], [[VAR_14_]] : vector<4xf32> +// CHECK-DAG: [[VAR_17_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_18_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_19_:%.+]] = arith.addf [[VAR_18_]], [[VAR_17_]] : vector<4xf32> +// CHECK-DAG: [[VAR_20_:%.+]] = vector.shuffle [[VAR_16_]], [[VAR_19_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_21_:%.+]] = vector.shuffle [[VAR_16_]], [[VAR_19_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_20_:%.+]] = arith.addf [[VAR_18_]], [[VAR_19_]] : vector<4xf32> -// CHECK-DAG: [[VAR_21_:%.+]] = vector.splat [[VAR_4_]] : vector<4xf32> -// CHECK: [[VAR_22_:%.+]] = arith.divf [[VAR_20_]], [[VAR_21_]] : vector<4xf32> -// CHECK: vector.store [[VAR_22_]], [[VAR_reshape_]]{{.}}[[VAR_6_]]#0, [[VAR_6_]]#1] : memref, vector<4xf32> +// CHECK-DAG: [[VAR_22_:%.+]] = arith.addf [[VAR_21_]], [[VAR_20_]] : vector<4xf32> +// CHECK-DAG: [[VAR_23_:%.+]] = vector.splat [[VAR_4_]] : vector<4xf32> +// CHECK: [[VAR_24_:%.+]] = arith.divf [[VAR_22_]], [[VAR_23_]] : vector<4xf32> +// CHECK: vector.store [[VAR_24_]], [[VAR_reshape_]]{{.}}[[VAR_6_]]#0, [[VAR_6_]]#1] : memref, vector<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref // CHECK: } @@ -1131,7 +1215,6 @@ func.func private @bertsquad10_const_pattern(%arg0 : tensor<1x256x768xf32>) -> t // CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index // CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index -// CHECK-DAG: [[CST_768_:%.+]] = arith.constant 768 : index // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<1x256x1xf32> // CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<2xindex> // CHECK: affine.store [[CST_1_]], [[RES_1_]][0] : memref<2xindex> @@ -1142,51 +1225,50 @@ func.func private @bertsquad10_const_pattern(%arg0 : tensor<1x256x768xf32>) -> t // CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 1, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 256){ // CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) // CHECK-DAG: [[RES_2_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]#1) +// CHECK-DAG: [[VAR_3_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#1) +// CHECK-DAG: [[VAR_4_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#1) // CHECK: vector.store [[VAR_cst_0_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_0_]], 
[[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_0_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: vector.store [[VAR_cst_0_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 -// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) -// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_2_:%.+]] = [[CST_0_]] to [[CST_768_]]){ -// CHECK: [[VAR_17_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index -// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[VAR_1_]]7] : memref<1x256x768xf32>, vector<4xf32> +// CHECK: affine.for [[I_2_:%.+]] = 0 to 768 step 4 { +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1, [[I_2_]]{{.}} : memref<1x256x768xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_2_]], [[I_2_]]{{.}} : memref<1x256x768xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_3_]], [[I_2_]]{{.}} : memref<1x256x768xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_4_]], [[I_2_]]{{.}} : memref<1x256x768xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_20_:%.+]] = arith.addf [[LOAD_RES_2_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> -// CHECK: vector.store [[VAR_20_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_21_:%.+]] = affine.apply [[MAP_0_]]([[VAR_1_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_21_]], [[VAR_1_]]7] : memref<1x256x768xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_1_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_24_:%.+]] = arith.addf [[LOAD_RES_2_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> -// CHECK: vector.store [[VAR_24_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_25_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_2_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_25_]], [[VAR_1_]]7] : memref<1x256x768xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_2_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_28_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> -// CHECK: vector.store [[VAR_28_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_29_:%.+]] = affine.apply [[MAP_2_]]([[VAR_1_]]#1) -// CHECK-DAG: [[LOAD_PARAM_0_MEM_3_:%.+]] = vector.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_29_]], [[VAR_1_]]7] : memref<1x256x768xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_3_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> -// CHECK: [[VAR_32_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> -// CHECK: vector.store [[VAR_32_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// 
CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_27_:%.+]] = arith.addf [[LOAD_RES_2_MEM_]], [[LOAD_PARAM_0_MEM_]] : vector<4xf32> +// CHECK-DAG: [[VAR_28_:%.+]] = arith.addf [[LOAD_RES_2_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = arith.addf [[LOAD_RES_2_MEM_2_]], [[LOAD_PARAM_0_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_30_:%.+]] = arith.addf [[LOAD_RES_2_MEM_3_]], [[LOAD_PARAM_0_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_27_]], [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_28_]], [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_29_]], [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_30_]], [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK: } // CHECK-DAG: [[LOAD_RES_2_MEM_4_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_5_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_6_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-DAG: [[LOAD_RES_2_MEM_7_:%.+]] = vector.load [[RES_2_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_7_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_8_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_9_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_10_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_4_]], [[LOAD_RES_2_MEM_5_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> // CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_9_:%.+]] = arith.addf [[VAR_7_]], [[VAR_8_]] : vector<4xf32> -// CHECK-DAG: [[VAR_10_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_12_:%.+]] = arith.addf [[VAR_10_]], [[VAR_11_]] : vector<4xf32> -// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[VAR_9_]], [[VAR_12_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> -// CHECK-DAG: [[VAR_14_:%.+]] = vector.shuffle [[VAR_9_]], [[VAR_12_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> -// CHECK: [[VAR_15_:%.+]] = arith.addf [[VAR_13_]], [[VAR_14_]] : vector<4xf32> -// CHECK: [[VAR_16_:%.+]] = arith.divf [[VAR_15_]], [[VAR_cst_]] : vector<4xf32> -// CHECK: vector.store [[VAR_16_]], [[VAR_reshape_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<1x256xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = arith.addf [[VAR_10_]], [[VAR_9_]] : vector<4xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_13_:%.+]] = vector.shuffle [[LOAD_RES_2_MEM_6_]], [[LOAD_RES_2_MEM_7_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_14_:%.+]] = arith.addf [[VAR_13_]], [[VAR_12_]] : vector<4xf32> +// CHECK-DAG: [[VAR_15_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [0, 1, 
4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_16_:%.+]] = vector.shuffle [[VAR_11_]], [[VAR_14_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_17_:%.+]] = arith.addf [[VAR_16_]], [[VAR_15_]] : vector<4xf32> +// CHECK: [[VAR_18_:%.+]] = arith.divf [[VAR_17_]], [[VAR_cst_]] : vector<4xf32> +// CHECK: vector.store [[VAR_18_]], [[VAR_reshape_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<1x256xf32>, vector<4xf32> // CHECK: } // CHECK: return [[RES_]] : memref<1x256x1xf32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_canonicalize.mlir index 149c8e7c2a..3a5354b54d 100644 --- a/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/NN/Normalization_O3_SIMD_canonicalize.mlir @@ -23,7 +23,7 @@ func.func @layernorm_4D_with_scale_bias(%arg0: tensor<2x64x32x8xf32>, %arg1: ten // ----- -// collapsed range is not a multiple of 4, cannot do simd +// collapsed range is not a multiple of 4; this used to prevent SIMD, but that case is now supported. func.func @layernorm_4D_with_scale_bias_no_SIMD(%arg0: tensor<2x64x31x3xf32>, %arg1: tensor<31x3xf32>, %arg2: tensor<31x3xf32>) -> tensor<*xf32> { %0 = "onnx.NoValue"() {value} : () -> none @@ -31,12 +31,445 @@ func.func @layernorm_4D_with_scale_bias_no_SIMD(%arg0: tensor<2x64x31x3xf32>, %a onnx.Return %Y : tensor<*xf32> // mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0 + 1)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 + 2)> +// CHECK-DAG: [[MAP_2_:#.+]] = affine_map<(d0) -> (d0 + 3)> // CHECK-LABEL: func.func @layernorm_4D_with_scale_bias_no_SIMD // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<2x64x31x3xf32>, [[PARAM_1_:%.+]]: memref<31x3xf32>, [[PARAM_2_:%.+]]: memref<31x3xf32>) -> memref<2x64x31x3xf32> { -// CHECK: [[LOOP_0_:%.+]]:4 = krnl.define_loops 4 -// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2, [[LOOP_0_]]#3) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 2, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 64, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 31, [[LOOP_0_]]#3 -> [[I_3_:%.+]] = 0 to 3){ -// CHECK: [[LOOP_1_:%.+]]:4 = krnl.define_loops 4 -// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1, [[LOOP_1_]]#2, [[LOOP_1_]]#3) with ([[LOOP_1_]]#0 -> [[I_4_:%.+]] = 0 to 2, [[LOOP_1_]]#1 -> [[I_5_:%.+]] = 0 to 64, [[LOOP_1_]]#2 -> [[I_6_:%.+]] = 0 to 1, [[LOOP_1_]]#3 -> [[I_7_:%.+]] = 0 to 1){ +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<1.000000e+00> : vector<32xf32> +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<9.300000e+01> : vector<4xf32> +// CHECK-DAG: [[VAR_cst_1_:%.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> +// CHECK-DAG: [[CST_92_:%.+]] = arith.constant 92 : index +// CHECK-DAG: [[CST_90_:%.+]] = arith.constant 90 : index +// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_93_:%.+]] = arith.constant 93 : index +// CHECK-DAG: [[CST_11904_:%.+]] = arith.constant 11904 : index +// CHECK-DAG: [[CST_128_:%.+]] = arith.constant 128 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : index +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[VAR_0_:%.+]] = "krnl.global"() {name = "constant_{{[0-9]+}}", shape = [1], value = dense<9.99999974E-6> : tensor<1xf32>} : () -> 
memref<1xf32> +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_1_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_1_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_1_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_1_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() {{.*}}: memref<2xindex> +// CHECK: affine.store [[CST_2_]], [[RES_2_]][0] : memref<2xindex> +// CHECK: affine.store [[CST_64_]], [[RES_2_]][1] : memref<2xindex> +// CHECK-DAG: [[VAR_reshape_4_:%.+]] = memref.reshape [[RES_]]([[RES_]]_3) : (memref<2x64x1x1xf32>, memref<2xindex>) -> memref<2x64xf32> +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK: [[BLOCK_TILE__0_:%.+]], [[BLOCK_IN__0_:%.+]] = krnl.block [[LOOP_0_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 2, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[BLOCK_TILE__0_]]) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = affine.apply [[MAP_0_]]([[VAR_8_]]#1) +// CHECK-DAG: [[VAR_10_:%.+]] = affine.apply [[MAP_1_]]([[VAR_8_]]#1) +// CHECK-DAG: [[VAR_11_:%.+]] = affine.apply [[MAP_2_]]([[VAR_8_]]#1) +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: scf.for [[I_2_:%.+]] = [[CST_0_]] to [[CST_90_]] step [[CST_4_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_9_]], [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_10_]], [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_11_]], [[I_2_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_46_:%.+]] = arith.addf [[LOAD_RES_3_MEM_]], [[LOAD_VAR_reshape_MEM_]] : vector<4xf32> +// CHECK-DAG: [[VAR_47_:%.+]] = arith.addf 
[[LOAD_RES_3_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_48_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_49_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_46_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_47_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_48_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_49_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1, [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_9_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_6_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_10_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_7_:%.+]] = memref.load [[VAR_reshape_]]{{.}}[[VAR_8_]]#0, [[VAR_11_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = memref.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_20_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_4_]] : f32 +// CHECK-DAG: [[VAR_21_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_5_]] : f32 +// CHECK-DAG: [[VAR_22_:%.+]] = arith.addf [[LOAD_RES_3_MEM_6_]], [[LOAD_VAR_reshape_MEM_6_]] : f32 +// CHECK-DAG: [[VAR_23_:%.+]] = arith.addf [[LOAD_RES_3_MEM_7_]], [[LOAD_VAR_reshape_MEM_7_]] : f32 +// CHECK: memref.store [[VAR_20_]], [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_21_]], [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_22_]], [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_23_]], [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_11_:%.+]] = vector.load [[RES_3_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_29_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// 
CHECK-DAG: [[VAR_30_:%.+]] = arith.addf [[VAR_29_]], [[VAR_28_]] : vector<4xf32> +// CHECK-DAG: [[VAR_31_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_32_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_33_:%.+]] = arith.addf [[VAR_32_]], [[VAR_31_]] : vector<4xf32> +// CHECK-DAG: [[VAR_34_:%.+]] = vector.shuffle [[VAR_30_]], [[VAR_33_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_35_:%.+]] = vector.shuffle [[VAR_30_]], [[VAR_33_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_36_:%.+]] = arith.addf [[VAR_35_]], [[VAR_34_]] : vector<4xf32> +// CHECK: [[VAR_37_:%.+]] = arith.divf [[VAR_36_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK: vector.store [[VAR_37_]], [[VAR_reshape_4_]]{{.}}[[VAR_8_]]#0, [[VAR_8_]]#1] : memref<2x64xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[RES_4_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_5_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_7_:%.+]] = memref.reshape [[RES_]]([[RES_]]_6) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_6_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_9_:%.+]] = memref.reshape [[RES_]]([[RES_]]_8) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_7_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_7_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_11_:%.+]] = memref.reshape [[RES_4_]]([[RES_7_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_1_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__1_:%.+]], [[BLOCK_IN__1_:%.+]] = krnl.block [[LOOP_1_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__1_]]) with ([[LOOP_1_]] -> [[I_3_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_1_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__1_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_10_1_:%.+]] = vector.load [[VAR_reshape_7_]]{{.}}[[VAR_9_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_1_:%.+]] = vector.load [[VAR_reshape_9_]]{{.}}[[VAR_9_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_:%.+]] = arith.mulf [[VAR_10_1_]], [[VAR_11_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_4_]], [[VAR_reshape_11_]]{{.}}[[VAR_9_1_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_8_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_9_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_11904_]], [[RES_9_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_9_]]) : (memref<2x64x31x3xf32>, memref<1xindex>) -> memref<11904xf32> +// CHECK-DAG: [[RES_10_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_11904_]], [[RES_10_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_16_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_10_]]) : (memref<2x64x31x3xf32>, memref<1xindex>) -> memref<11904xf32> +// CHECK-DAG: [[RES_11_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store 
[[CST_11904_]], [[RES_11_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_18_:%.+]] = memref.reshape [[RES_8_]]([[RES_11_]]) : (memref<2x64x31x3xf32>, memref<1xindex>) -> memref<11904xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__2_:%.+]], [[BLOCK_IN__2_:%.+]] = krnl.block [[LOOP_2_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to 11904){ +// CHECK: [[VAR_9_2_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__2_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_10_1_:%.+]] = vector.load [[VAR_reshape_14_]]{{.}}[[VAR_9_2_]]{{.}} : memref<11904xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_1_:%.+]] = vector.load [[VAR_reshape_16_]]{{.}}[[VAR_9_2_]]{{.}} : memref<11904xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_1_:%.+]] = arith.mulf [[VAR_10_1_]], [[VAR_11_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_4_1_]], [[VAR_reshape_18_]]{{.}}[[VAR_9_2_]]{{.}} : memref<11904xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_12_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_13_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_13_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_13_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_13_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_21_:%.+]] = memref.reshape [[RES_8_]]([[RES_13_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_14_:%.+]] = memref.alloc() {{.*}}: memref<2xindex> +// CHECK: affine.store [[CST_2_]], [[RES_14_]][0] : memref<2xindex> +// CHECK: affine.store [[CST_64_]], [[RES_14_]][1] : memref<2xindex> +// CHECK-DAG: [[VAR_reshape_23_:%.+]] = memref.reshape [[RES_12_]]([[RES_14_]]) : (memref<2x64x1x1xf32>, memref<2xindex>) -> memref<2x64xf32> +// CHECK-DAG: [[LOOP_3_:%.+]]:2 = krnl.define_loops 2 +// CHECK: [[BLOCK_TILE__3_:%.+]], [[BLOCK_IN__3_:%.+]] = krnl.block [[LOOP_3_]]#1 4 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[LOOP_3_]]#0, [[BLOCK_TILE__3_]]) with ([[LOOP_3_]]#0 -> [[I_5_:%.+]] = 0 to 2, [[LOOP_3_]]#1 -> [[I_6_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_3_]]#0, [[BLOCK_TILE__3_]]) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[RES_15_:%.+]] = memref.alloca() {{.*}}: memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_3_:%.+]] = affine.apply [[MAP_0_]]([[VAR_8_1_]]#1) +// CHECK-DAG: [[VAR_10_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_8_1_]]#1) +// CHECK-DAG: [[VAR_11_2_:%.+]] = affine.apply [[MAP_2_]]([[VAR_8_1_]]#1) +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_cst_1_]], [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: scf.for [[I_7_:%.+]] = [[CST_0_]] to [[CST_90_]] step [[CST_4_]] { +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_8_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1, [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_1_:%.+]] = 
vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_9_3_]], [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_2_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_10_2_]], [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_3_:%.+]] = vector.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_11_2_]], [[I_7_]]{{.}} : memref<2x64x93xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_12_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_3_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_46_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_12_]], [[LOAD_VAR_reshape_MEM_8_]] : vector<4xf32> +// CHECK-DAG: [[VAR_47_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_1_]], [[LOAD_VAR_reshape_MEM_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_48_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_2_]], [[LOAD_VAR_reshape_MEM_2_]] : vector<4xf32> +// CHECK-DAG: [[VAR_49_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_3_]], [[LOAD_VAR_reshape_MEM_3_]] : vector<4xf32> +// CHECK: vector.store [[VAR_46_1_]], [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_47_1_]], [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_48_1_]], [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: vector.store [[VAR_49_1_]], [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1, [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_9_3_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_6_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_10_2_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_7_:%.+]] = memref.load [[VAR_reshape_21_]]{{.}}[[VAR_8_1_]]#0, [[VAR_11_2_]], [[CST_92_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_4_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_5_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_6_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_7_:%.+]] = memref.load [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_20_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_4_]], [[LOAD_VAR_reshape_MEM_4_1_]] : f32 +// CHECK-DAG: [[VAR_21_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_5_]], [[LOAD_VAR_reshape_MEM_5_]] : f32 +// CHECK-DAG: [[VAR_22_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_6_]], [[LOAD_VAR_reshape_MEM_6_]] : f32 +// CHECK-DAG: [[VAR_23_1_:%.+]] = arith.addf [[LOAD_RES_3_MEM_7_]], [[LOAD_VAR_reshape_MEM_7_]] : f32 +// CHECK: memref.store [[VAR_20_1_]], 
[[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_21_1_]], [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_22_1_]], [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK: memref.store [[VAR_23_1_]], [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_8_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_0_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_9_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_1_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_10_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_2_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-DAG: [[LOAD_RES_3_MEM_11_:%.+]] = vector.load [[RES_15_]]{{.}}[[CST_3_]], [[CST_0_]]{{.}} : memref<4x4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_28_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_29_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_8_]], [[LOAD_RES_3_MEM_9_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_30_1_:%.+]] = arith.addf [[VAR_29_1_]], [[VAR_28_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_31_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [0, 4, 1, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_32_1_:%.+]] = vector.shuffle [[LOAD_RES_3_MEM_10_]], [[LOAD_RES_3_MEM_11_]] [2, 6, 3, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_33_1_:%.+]] = arith.addf [[VAR_32_1_]], [[VAR_31_1_]] : vector<4xf32> +// CHECK-DAG: [[VAR_34_1_:%.+]] = vector.shuffle [[VAR_30_1_]], [[VAR_33_1_]] [0, 1, 4, 5] : vector<4xf32>, vector<4xf32> +// CHECK-DAG: [[VAR_35_1_:%.+]] = vector.shuffle [[VAR_30_1_]], [[VAR_33_1_]] [2, 3, 6, 7] : vector<4xf32>, vector<4xf32> +// CHECK: [[VAR_36_1_:%.+]] = arith.addf [[VAR_35_1_]], [[VAR_34_1_]] : vector<4xf32> +// CHECK: [[VAR_37_1_:%.+]] = arith.divf [[VAR_36_1_]], [[VAR_cst_0_]] : vector<4xf32> +// CHECK: vector.store [[VAR_37_1_]], [[VAR_reshape_23_]]{{.}}[[VAR_8_1_]]#0, [[VAR_8_1_]]#1] : memref<2x64xf32>, vector<4xf32> +// CHECK: } +// CHECK-DAG: [[RES_16_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_17_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_17_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_28_:%.+]] = memref.reshape [[RES_12_]]([[RES_17_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_18_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_18_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_30_:%.+]] = memref.reshape [[RES_4_]]([[RES_18_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_19_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_19_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_32_:%.+]] = memref.reshape [[RES_16_]]([[RES_19_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_4_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__4_:%.+]], [[BLOCK_IN__4_:%.+]] = krnl.block [[LOOP_4_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__4_]]) with ([[LOOP_4_]] -> [[I_8_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_4_:%.+]] = 
krnl.get_induction_var_value([[BLOCK_TILE__4_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[VAR_10_2_:%.+]] = vector.load [[VAR_reshape_28_]]{{.}}[[VAR_9_4_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_2_:%.+]] = vector.load [[VAR_reshape_30_]]{{.}}[[VAR_9_4_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_1_1_:%.+]] = arith.subf [[VAR_10_2_]], [[VAR_11_2_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_4_1_1_]], [[VAR_reshape_32_]]{{.}}[[VAR_9_4_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_20_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_21_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_21_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_21_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_21_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_35_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_21_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_22_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_22_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_22_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_1_]], [[RES_22_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_37_:%.+]] = memref.reshape [[RES_]]([[RES_]]_36) : (memref<2x64x1x1xf32>, memref<3xindex>) -> memref<2x64x1xf32> +// CHECK-DAG: [[RES_23_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_23_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_23_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_23_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_39_:%.+]] = memref.reshape [[RES_20_]]([[RES_23_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_5_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_5_]]#0, [[LOOP_5_]]#1) with ([[LOOP_5_]]#0 -> [[I_9_:%.+]] = 0 to 2, [[LOOP_5_]]#1 -> [[I_10_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_2_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_5_]]#0, [[LOOP_5_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_6_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__5_:%.+]], [[BLOCK_IN__5_:%.+]] = krnl.block [[LOOP_6_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__5_]]) with ([[LOOP_6_]] -> [[I_11_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_3_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__5_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_:%.+]] = vector.load [[VAR_reshape_35_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_3_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_:%.+]] = krnl.load [[VAR_reshape_37_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_:%.+]] = vector.splat [[LOAD_VAR_reshape_MEM_5_1_]] : vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_7_1_:%.+]] = arith.subf [[LOAD_VAR_reshape_MEM_4_1_1_]], [[LOAD_VAR_reshape_MEM_6_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_7_1_]], [[VAR_reshape_39_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_3_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_7_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_7_]]) with ([[LOOP_7_]] -> [[I_12_:%.+]] = 64 to 93){ +// 
CHECK: [[VAR_11_4_:%.+]] = krnl.get_induction_var_value([[LOOP_7_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_:%.+]] = krnl.load [[VAR_reshape_35_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_4_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_:%.+]] = krnl.load [[VAR_reshape_37_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_:%.+]] = arith.subf [[LOAD_VAR_reshape_MEM_4_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_]], [[VAR_reshape_39_]]{{.}}[[VAR_8_2_]]#0, [[VAR_8_2_]]#1, [[VAR_11_4_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_24_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_25_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_25_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_42_:%.+]] = memref.reshape [[RES_16_]]([[RES_25_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_26_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_26_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_44_:%.+]] = memref.reshape [[RES_24_]]([[RES_26_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_8_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__6_:%.+]], [[BLOCK_IN__6_:%.+]] = krnl.block [[LOOP_8_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__6_]]) with ([[LOOP_8_]] -> [[I_13_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_5_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__6_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOOP_7_:%.+]] = vector.load [[VAR_reshape_42_]]{{.}}[[VAR_9_5_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK-DAG: [[VAR_11_4_:%.+]] = krnl.load [[VAR_0_]]{{.}}[[CST_0_]]{{.}} : memref<1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_4_1_1_1_:%.+]] = vector.splat [[VAR_11_4_]] : vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_5_1_1_:%.+]] = arith.addf [[LOOP_7_]], [[LOAD_VAR_reshape_MEM_4_1_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_5_1_1_]], [[VAR_reshape_44_]]{{.}}[[VAR_9_5_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_27_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_28_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_28_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_47_:%.+]] = memref.reshape [[RES_24_]]([[RES_28_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_29_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_29_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_49_:%.+]] = memref.reshape [[RES_27_]]([[RES_29_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_9_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__7_:%.+]], [[BLOCK_IN__7_:%.+]] = krnl.block [[LOOP_9_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__7_]]) with ([[LOOP_9_]] -> [[I_14_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_6_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__7_]]) : (!krnl.loop) -> index +// CHECK: [[LOOP_7_1_:%.+]] = vector.load [[VAR_reshape_47_]]{{.}}[[VAR_9_6_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[VAR_11_5_:%.+]] = 
math.sqrt [[LOOP_7_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_11_5_]], [[VAR_reshape_49_]]{{.}}[[VAR_9_6_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_30_:%.+]] = memref.alloc() {{.*}}: memref<2x64x1x1xf32> +// CHECK-DAG: [[RES_31_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_31_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_52_:%.+]] = memref.reshape [[RES_27_]]([[RES_31_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK-DAG: [[RES_32_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_128_]], [[RES_32_]][0] : memref<1xindex> +// CHECK: [[VAR_reshape_54_:%.+]] = memref.reshape [[RES_30_]]([[RES_32_]]) : (memref<2x64x1x1xf32>, memref<1xindex>) -> memref<128xf32> +// CHECK: krnl.iterate() with (){ +// CHECK: [[LOOP_10_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__8_:%.+]], [[BLOCK_IN__8_:%.+]] = krnl.block [[LOOP_10_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__8_]]) with ([[LOOP_10_]] -> [[I_15_:%.+]] = 0 to 128){ +// CHECK: [[VAR_9_7_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__8_]]) : (!krnl.loop) -> index +// CHECK: [[LOOP_7_1_:%.+]] = vector.load [[VAR_reshape_52_]]{{.}}[[VAR_9_7_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: [[VAR_11_6_:%.+]] = arith.divf [[VAR_cst_]], [[LOOP_7_1_]] : vector<32xf32> +// CHECK: vector.store [[VAR_11_6_]], [[VAR_reshape_54_]]{{.}}[[VAR_9_7_]]{{.}} : memref<128xf32>, vector<32xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_33_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_34_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_34_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_34_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_34_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_57_:%.+]] = memref.reshape [[RES_20_]]([[RES_34_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_35_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_35_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_35_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_1_]], [[RES_35_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_59_:%.+]] = memref.reshape [[RES_30_]]([[RES_35_]]) : (memref<2x64x1x1xf32>, memref<3xindex>) -> memref<2x64x1xf32> +// CHECK-DAG: [[RES_36_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_36_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_36_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_36_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_61_:%.+]] = memref.reshape [[RES_33_]]([[RES_36_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_11_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_11_]]#0, [[LOOP_11_]]#1) with ([[LOOP_11_]]#0 -> [[I_16_:%.+]] = 0 to 2, [[LOOP_11_]]#1 -> [[I_17_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_3_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_11_]]#0, [[LOOP_11_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_12_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__9_:%.+]], [[BLOCK_IN__9_:%.+]] = krnl.block [[LOOP_12_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__9_]]) with ([[LOOP_12_]] -> [[I_18_:%.+]] = 0 to 
62){ +// CHECK: [[VAR_11_7_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__9_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_57_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_7_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_:%.+]] = krnl.load [[VAR_reshape_59_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_:%.+]] = vector.splat [[LOAD_VAR_reshape_MEM_5_1_1_]] : vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_7_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_6_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_7_1_]], [[VAR_reshape_61_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_7_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_13_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_13_]]) with ([[LOOP_13_]] -> [[I_19_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_8_:%.+]] = krnl.get_induction_var_value([[LOOP_13_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_57_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_8_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_:%.+]] = krnl.load [[VAR_reshape_59_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[CST_0_]]{{.}} : memref<2x64x1xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_1_]], [[VAR_reshape_61_]]{{.}}[[VAR_8_3_]]#0, [[VAR_8_3_]]#1, [[VAR_11_8_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_37_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_38_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_38_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_38_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_38_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_64_:%.+]] = memref.reshape [[RES_33_]]([[RES_38_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_39_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_93_]], [[RES_39_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_66_:%.+]] = memref.reshape [[PARAM_1_]]([[RES_39_]]) : (memref<31x3xf32>, memref<1xindex>) -> memref<93xf32> +// CHECK-DAG: [[RES_40_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_40_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_40_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_40_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_68_:%.+]] = memref.reshape [[RES_37_]]([[RES_40_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_14_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_14_]]#0, [[LOOP_14_]]#1) with ([[LOOP_14_]]#0 -> [[I_20_:%.+]] = 0 to 2, [[LOOP_14_]]#1 -> [[I_21_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_4_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_14_]]#0, [[LOOP_14_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_15_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__10_:%.+]], [[BLOCK_IN__10_:%.+]] = krnl.block [[LOOP_15_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__10_]]) with 
([[LOOP_15_]] -> [[I_22_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_9_:%.+]] = krnl.get_induction_var_value([[BLOCK_TILE__10_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_64_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_9_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_:%.+]] = vector.load [[VAR_reshape_66_]]{{.}}[[VAR_11_9_]]{{.}} : memref<93xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_6_1_1_1_]], [[VAR_reshape_68_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_9_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_16_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_16_]]) with ([[LOOP_16_]] -> [[I_23_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_10_:%.+]] = krnl.get_induction_var_value([[LOOP_16_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_64_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_10_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_66_]]{{.}}[[VAR_11_10_]]{{.}} : memref<93xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_:%.+]] = arith.mulf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_1_1_]], [[VAR_reshape_68_]]{{.}}[[VAR_8_4_]]#0, [[VAR_8_4_]]#1, [[VAR_11_10_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK-DAG: [[RES_41_:%.+]] = memref.alloc() {{.*}}: memref<2x64x31x3xf32> +// CHECK-DAG: [[RES_42_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_42_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_42_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_42_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_71_:%.+]] = memref.reshape [[RES_37_]]([[RES_42_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[RES_43_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[CST_93_]], [[RES_43_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_73_:%.+]] = memref.reshape [[PARAM_2_]]([[RES_43_]]) : (memref<31x3xf32>, memref<1xindex>) -> memref<93xf32> +// CHECK-DAG: [[RES_44_:%.+]] = memref.alloc() {{.*}}: memref<3xindex> +// CHECK: affine.store [[CST_2_]], [[RES_44_]][0] : memref<3xindex> +// CHECK: affine.store [[CST_64_]], [[RES_44_]][1] : memref<3xindex> +// CHECK: affine.store [[CST_93_]], [[RES_44_]][2] : memref<3xindex> +// CHECK-DAG: [[VAR_reshape_75_:%.+]] = memref.reshape [[RES_41_]]([[RES_44_]]) : (memref<2x64x31x3xf32>, memref<3xindex>) -> memref<2x64x93xf32> +// CHECK-DAG: [[LOOP_17_:%.+]]:2 = krnl.define_loops 2 +// CHECK: krnl.iterate([[LOOP_17_]]#0, [[LOOP_17_]]#1) with ([[LOOP_17_]]#0 -> [[I_24_:%.+]] = 0 to 2, [[LOOP_17_]]#1 -> [[I_25_:%.+]] = 0 to 64){ +// CHECK-DAG: [[VAR_8_5_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_17_]]#0, [[LOOP_17_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOOP_18_:%.+]] = krnl.define_loops 1 +// CHECK: [[BLOCK_TILE__11_:%.+]], [[BLOCK_IN__11_:%.+]] = krnl.block [[LOOP_18_]] 32 : (!krnl.loop) -> (!krnl.loop, !krnl.loop) +// CHECK: krnl.iterate([[BLOCK_TILE__11_]]) with ([[LOOP_18_]] -> [[I_26_:%.+]] = 0 to 62){ +// CHECK: [[VAR_11_11_:%.+]] = 
krnl.get_induction_var_value([[BLOCK_TILE__11_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_71_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_11_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_1_:%.+]] = vector.load [[VAR_reshape_73_]]{{.}}[[VAR_11_11_]]{{.}} : memref<93xf32>, vector<32xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_1_]] : vector<32xf32> +// CHECK: vector.store [[LOAD_VAR_reshape_MEM_6_1_1_1_1_]], [[VAR_reshape_75_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_11_]]{{.}} : memref<2x64x93xf32>, vector<32xf32> +// CHECK: } +// CHECK: [[LOOP_19_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_19_]]) with ([[LOOP_19_]] -> [[I_27_:%.+]] = 64 to 93){ +// CHECK: [[VAR_11_12_:%.+]] = krnl.get_induction_var_value([[LOOP_19_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_71_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_12_]]{{.}} : memref<2x64x93xf32> +// CHECK-DAG: [[LOAD_VAR_reshape_MEM_5_1_1_1_1_1_:%.+]] = krnl.load [[VAR_reshape_73_]]{{.}}[[VAR_11_12_]]{{.}} : memref<93xf32> +// CHECK: [[LOAD_VAR_reshape_MEM_6_1_1_1_1_:%.+]] = arith.addf [[LOAD_VAR_reshape_MEM_4_1_1_1_1_1_1_]], [[LOAD_VAR_reshape_MEM_5_1_1_1_1_1_]] : f32 +// CHECK: krnl.store [[LOAD_VAR_reshape_MEM_6_1_1_1_1_]], [[VAR_reshape_75_]]{{.}}[[VAR_8_5_]]#0, [[VAR_8_5_]]#1, [[VAR_11_12_]]{{.}} : memref<2x64x93xf32> +// CHECK: } +// CHECK: } +// CHECK: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[RES_41_]] : memref<2x64x31x3xf32> to tensor<2x64x31x3xf32> +// CHECK: onnx.Return [[VAR_7_]] : tensor<2x64x31x3xf32> +// CHECK: } } // -----
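Note on the pattern exercised by the checks above: the pairs of vector.shuffle followed by arith.addf implement a 4-way shuffle-tree horizontal reduction, turning four <4xf32> accumulators into a single <4xf32> whose lane i holds the full horizontal sum of accumulator i. A minimal MLIR sketch of that idiom (illustrative only; the standalone function and SSA names are not part of this patch):

func.func @shuffle_reduce4(%v0: vector<4xf32>, %v1: vector<4xf32>,
                           %v2: vector<4xf32>, %v3: vector<4xf32>) -> vector<4xf32> {
  // Interleave (v0, v1) lanes and add:
  // %s01 = [v0[0]+v0[2], v1[0]+v1[2], v0[1]+v0[3], v1[1]+v1[3]].
  %a = vector.shuffle %v0, %v1 [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
  %b = vector.shuffle %v0, %v1 [2, 6, 3, 7] : vector<4xf32>, vector<4xf32>
  %s01 = arith.addf %b, %a : vector<4xf32>
  // Same interleave-and-add for (v2, v3).
  %c = vector.shuffle %v2, %v3 [0, 4, 1, 5] : vector<4xf32>, vector<4xf32>
  %d = vector.shuffle %v2, %v3 [2, 6, 3, 7] : vector<4xf32>, vector<4xf32>
  %s23 = arith.addf %d, %c : vector<4xf32>
  // Combine the two partial vectors; lane i of %r is the horizontal sum of %vi.
  %e = vector.shuffle %s01, %s23 [0, 1, 4, 5] : vector<4xf32>, vector<4xf32>
  %f = vector.shuffle %s01, %s23 [2, 3, 6, 7] : vector<4xf32>, vector<4xf32>
  %r = arith.addf %f, %e : vector<4xf32>
  return %r : vector<4xf32>
}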