diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 73ed611e8de8c4..3a98e257367b25 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -1256,7 +1256,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
                                  GlobalValue *BaseGV, int64_t BaseOffset,
                                  bool HasBaseReg, int64_t Scale,
-                                 Instruction *Fixup = nullptr);
+                                 Instruction *Fixup = nullptr,
+                                 int64_t ScalableOffset = 0);
 
 static unsigned getSetupCost(const SCEV *Reg, unsigned Depth) {
   if (isa<SCEVConstant>(Reg) || isa<SCEVUnknown>(Reg))
@@ -1675,16 +1676,18 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
                                  LSRUse::KindType Kind, MemAccessTy AccessTy,
                                  GlobalValue *BaseGV, int64_t BaseOffset,
                                  bool HasBaseReg, int64_t Scale,
-                                 Instruction *Fixup/*= nullptr*/) {
+                                 Instruction *Fixup /* = nullptr */,
+                                 int64_t ScalableOffset) {
   switch (Kind) {
   case LSRUse::Address:
     return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset,
-                                     HasBaseReg, Scale, AccessTy.AddrSpace, Fixup);
+                                     HasBaseReg, Scale, AccessTy.AddrSpace,
+                                     Fixup, ScalableOffset);
 
   case LSRUse::ICmpZero:
     // There's not even a target hook for querying whether it would be legal to
     // fold a GV into an ICmp.
-    if (BaseGV)
+    if (BaseGV || ScalableOffset != 0)
      return false;
 
     // ICmp only has two operands; don't allow more than two non-trivial parts.
@@ -1715,11 +1718,12 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
 
   case LSRUse::Basic:
     // Only handle single-register values.
-    return !BaseGV && Scale == 0 && BaseOffset == 0;
+    return !BaseGV && Scale == 0 && BaseOffset == 0 && ScalableOffset == 0;
 
   case LSRUse::Special:
     // Special case Basic to handle -1 scales.
-    return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0;
+    return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0 &&
+           ScalableOffset == 0;
   }
 
   llvm_unreachable("Invalid LSRUse Kind!");
@@ -1843,7 +1847,7 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
 static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
                              LSRUse::KindType Kind, MemAccessTy AccessTy,
                              GlobalValue *BaseGV, int64_t BaseOffset,
-                             bool HasBaseReg) {
+                             bool HasBaseReg, int64_t ScalableOffset = 0) {
   // Fast-path: zero is always foldable.
   if (BaseOffset == 0 && !BaseGV)
     return true;
@@ -1859,7 +1863,7 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
   }
 
   return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, BaseOffset,
-                              HasBaseReg, Scale);
+                              HasBaseReg, Scale, nullptr, ScalableOffset);
 }
 
 static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
@@ -3165,16 +3169,30 @@ void LSRInstance::FinalizeChain(IVChain &Chain) {
 static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
                              Value *Operand, const TargetTransformInfo &TTI) {
   const SCEVConstant *IncConst = dyn_cast<SCEVConstant>(IncExpr);
-  if (!IncConst || !isAddressUse(TTI, UserInst, Operand))
-    return false;
+  int64_t IncOffset = 0;
+  int64_t ScalableOffset = 0;
+  if (IncConst) {
+    if (IncConst->getAPInt().getSignificantBits() > 64)
+      return false;
+    IncOffset = IncConst->getValue()->getSExtValue();
+  } else {
+    // Look for mul(vscale, constant), to detect ScalableOffset.
+    auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
+    if (!IncVScale || IncVScale->getNumOperands() != 2 ||
+        !isa<SCEVVScale>(IncVScale->getOperand(1)))
+      return false;
+    auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
+    if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
+      return false;
+    ScalableOffset = Scale->getValue()->getSExtValue();
+  }
 
-  if (IncConst->getAPInt().getSignificantBits() > 64)
+  if (!isAddressUse(TTI, UserInst, Operand))
     return false;
 
   MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand);
-  int64_t IncOffset = IncConst->getValue()->getSExtValue();
   if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr,
-                        IncOffset, /*HasBaseReg=*/false))
+                        IncOffset, /*HasBaseReg=*/false, ScalableOffset))
     return false;
 
   return true;
@@ -3220,6 +3238,10 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
   Type *IVTy = IVSrc->getType();
   Type *IntTy = SE.getEffectiveSCEVType(IVTy);
   const SCEV *LeftOverExpr = nullptr;
+  const SCEV *Accum = SE.getZero(IntTy);
+  SmallVector<std::pair<const SCEV *, Value *>> Bases;
+  Bases.emplace_back(Accum, IVSrc);
+
   for (const IVInc &Inc : Chain) {
     Instruction *InsertPt = Inc.UserInst;
     if (isa<PHINode>(InsertPt))
@@ -3232,10 +3254,31 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
       // IncExpr was the result of subtraction of two narrow values, so must
       // be signed.
       const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy);
+      Accum = SE.getAddExpr(Accum, IncExpr);
       LeftOverExpr = LeftOverExpr ?
        SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr;
     }
-    if (LeftOverExpr && !LeftOverExpr->isZero()) {
+
+    // Look through each base to see if any can produce a nice addressing mode.
+    bool FoundBase = false;
+    for (auto [MapScev, MapIVOper] : reverse(Bases)) {
+      const SCEV *Remainder = SE.getMinusSCEV(Accum, MapScev);
+      if (canFoldIVIncExpr(Remainder, Inc.UserInst, Inc.IVOperand, TTI)) {
+        if (!Remainder->isZero()) {
+          Rewriter.clearPostInc();
+          Value *IncV = Rewriter.expandCodeFor(Remainder, IntTy, InsertPt);
+          const SCEV *IVOperExpr =
+              SE.getAddExpr(SE.getUnknown(MapIVOper), SE.getUnknown(IncV));
+          IVOper = Rewriter.expandCodeFor(IVOperExpr, IVTy, InsertPt);
+        } else {
+          IVOper = MapIVOper;
+        }
+
+        FoundBase = true;
+        break;
+      }
+    }
+    if (!FoundBase && LeftOverExpr && !LeftOverExpr->isZero()) {
       // Expand the IV increment.
       Rewriter.clearPostInc();
       Value *IncV = Rewriter.expandCodeFor(LeftOverExpr, IntTy, InsertPt);
@@ -3246,6 +3289,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
       // If an IV increment can't be folded, use it as the next IV value.
       if (!canFoldIVIncExpr(LeftOverExpr, Inc.UserInst, Inc.IVOperand, TTI)) {
         assert(IVTy == IVOper->getType() && "inconsistent IV increment type");
+        Bases.emplace_back(Accum, IVOper);
         IVSrc = IVOper;
         LeftOverExpr = nullptr;
       }
diff --git a/llvm/test/CodeGen/AArch64/sve-lsrchain.ll b/llvm/test/CodeGen/AArch64/sve-lsrchain.ll
index 9c7bffb921ce29..1931cfc2ef51de 100644
--- a/llvm/test/CodeGen/AArch64/sve-lsrchain.ll
+++ b/llvm/test/CodeGen/AArch64/sve-lsrchain.ll
@@ -14,24 +14,22 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
 ; CHECK-NEXT:  // %bb.2: // %for.body.us.preheader
 ; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    add x11, x2, x11, lsl #1
-; CHECK-NEXT:    mov x12, #-16 // =0xfffffffffffffff0
-; CHECK-NEXT:    ptrue p1.b
 ; CHECK-NEXT:    mov w8, wzr
+; CHECK-NEXT:    ptrue p1.b
 ; CHECK-NEXT:    mov x9, xzr
 ; CHECK-NEXT:    mov w10, wzr
-; CHECK-NEXT:    addvl x12, x12, #1
-; CHECK-NEXT:    mov x13, #4 // =0x4
-; CHECK-NEXT:    mov x14, #8 // =0x8
+; CHECK-NEXT:    mov x12, #4 // =0x4
+; CHECK-NEXT:    mov x13, #8 // =0x8
 ; CHECK-NEXT:  .LBB0_3: // %for.body.us
 ; CHECK-NEXT:    // =>This Loop Header: Depth=1
 ; CHECK-NEXT:    // Child Loop BB0_4 Depth 2
-; CHECK-NEXT:    add x15, x0, x9, lsl #2
-; CHECK-NEXT:    sbfiz x16, x8, #1, #32
-; CHECK-NEXT:    mov x17, x2
-; CHECK-NEXT:    ldp s0, s1, [x15]
-; CHECK-NEXT:    add x16, x16, #8
-; CHECK-NEXT:    ldp s2, s3, [x15, #8]
-; CHECK-NEXT:    ubfiz x15, x8, #1, #32
+; CHECK-NEXT:    add x14, x0, x9, lsl #2
+; CHECK-NEXT:    sbfiz x15, x8, #1, #32
+; CHECK-NEXT:    mov x16, x2
+; CHECK-NEXT:    ldp s0, s1, [x14]
+; CHECK-NEXT:    add x15, x15, #8
+; CHECK-NEXT:    ldp s2, s3, [x14, #8]
+; CHECK-NEXT:    ubfiz x14, x8, #1, #32
 ; CHECK-NEXT:    fcvt h0, s0
 ; CHECK-NEXT:    fcvt h1, s1
 ; CHECK-NEXT:    fcvt h2, s2
@@ -43,56 +41,52 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
 ; CHECK-NEXT:  .LBB0_4: // %for.cond.i.preheader.us
 ; CHECK-NEXT:    // Parent Loop BB0_3 Depth=1
 ; CHECK-NEXT:    // => This Inner Loop Header: Depth=2
-; CHECK-NEXT:    ld1b { z4.b }, p1/z, [x17, x15]
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17]
-; CHECK-NEXT:    add x18, x17, x16
-; CHECK-NEXT:    add x3, x17, x15
+; CHECK-NEXT:    ld1b { z4.b }, p1/z, [x16, x14]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16]
+; CHECK-NEXT:    add x17, x16, x15
+; CHECK-NEXT:    add x18, x16, x14
+; CHECK-NEXT:    add x3, x17, #8
+; CHECK-NEXT:    add x4, x17, #16
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x17, x16]
+; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x16, x15]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, x12, lsl #1]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
-; CHECK-NEXT:    add x18, x18, #16
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, x13, lsl #1]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #1, mul vl]
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17]
-; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x3, #1, mul vl]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16, #1, mul vl]
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16]
+; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x18, #1, mul vl]
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x18, x12]
-; CHECK-NEXT:    add x18, x18, x12
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #1, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x3, #1, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
-; CHECK-NEXT:    add x18, x18, #16
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x4, #1, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #2, mul vl]
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17, #1, mul vl]
-; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x3, #2, mul vl]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16, #2, mul vl]
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16, #1, mul vl]
+; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x18, #2, mul vl]
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x18, x12]
-; CHECK-NEXT:    add x18, x18, x12
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #2, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x3, #2, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
-; CHECK-NEXT:    add x18, x18, #16
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x4, #2, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #3, mul vl]
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17, #2, mul vl]
-; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x3, #3, mul vl]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x16, #3, mul vl]
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16, #2, mul vl]
+; CHECK-NEXT:    ld1h { z4.h }, p0/z, [x18, #3, mul vl]
 ; CHECK-NEXT:    fmad z4.h, p0/m, z0.h, z5.h
-; CHECK-NEXT:    ld1b { z5.b }, p1/z, [x18, x12]
-; CHECK-NEXT:    add x18, x18, x12
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x17, #3, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z1.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x3, #3, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z2.h
-; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
+; CHECK-NEXT:    ld1h { z5.h }, p0/z, [x4, #3, mul vl]
 ; CHECK-NEXT:    fmla z4.h, p0/m, z5.h, z3.h
-; CHECK-NEXT:    st1h { z4.h }, p0, [x17, #3, mul vl]
-; CHECK-NEXT:    addvl x17, x17, #4
-; CHECK-NEXT:    cmp x17, x11
+; CHECK-NEXT:    st1h { z4.h }, p0, [x16, #3, mul vl]
+; CHECK-NEXT:    addvl x16, x16, #4
+; CHECK-NEXT:    cmp x16, x11
 ; CHECK-NEXT:    b.lo .LBB0_4
 ; CHECK-NEXT:  // %bb.5: // %while.cond.i..exit_crit_edge.us
 ; CHECK-NEXT:    // in Loop: Header=BB0_3 Depth=1