From c7308d405d286674fbcd427da3f7a06f52dd70a2 Mon Sep 17 00:00:00 2001 From: David Green Date: Mon, 10 Jun 2024 20:35:33 +0100 Subject: [PATCH] [LSR][AArch64] Optimize chain generation based on legal addressing modes (#94453) LSR will generate chains of related instructions with a known increment between them. With SVE, in the case of the test case, this can include increments like 'vscale * 16 + 8'. The idea of this patch is if we have a '+8' increment already calculated in the chain, we can generate a (legal) '+ vscale*16' addressing mode from it, allowing us to use the '[x16, #1, mul vl]' addressing mode instructions. In order to do this we keep track of the known 'bases' when generating chains in GenerateIVChain, checking for each if the accumulated increment expression from the base neatly folds into a legal addressing mode. If they do not we fall back to the existing LeftOverExpr, whether it is legal or not. This is mostly orthogonal to #88124, dealing with the generation of chains as opposed to rest of LSR. The existing vscale addressing mode work has greatly helped compared to the last time I looked at this, allowing us to check that the addressing modes are indeed legal. --- .../Transforms/Scalar/LoopStrengthReduce.cpp | 72 +++++++++++++--- llvm/test/CodeGen/AArch64/sve-lsrchain.ll | 86 +++++++++---------- 2 files changed, 98 insertions(+), 60 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 73ed611e8de8c4..3a98e257367b25 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -1256,7 +1256,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, LSRUse::KindType Kind, MemAccessTy AccessTy, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, - Instruction *Fixup = nullptr); + Instruction *Fixup = nullptr, + int64_t ScalableOffset = 0); static unsigned getSetupCost(const SCEV *Reg, unsigned Depth) { if (isa(Reg) || isa(Reg)) @@ -1675,16 +1676,18 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, LSRUse::KindType Kind, MemAccessTy AccessTy, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, - Instruction *Fixup/*= nullptr*/) { + Instruction *Fixup /* = nullptr */, + int64_t ScalableOffset) { switch (Kind) { case LSRUse::Address: return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset, - HasBaseReg, Scale, AccessTy.AddrSpace, Fixup); + HasBaseReg, Scale, AccessTy.AddrSpace, + Fixup, ScalableOffset); case LSRUse::ICmpZero: // There's not even a target hook for querying whether it would be legal to // fold a GV into an ICmp. - if (BaseGV) + if (BaseGV || ScalableOffset != 0) return false; // ICmp only has two operands; don't allow more than two non-trivial parts. @@ -1715,11 +1718,12 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, case LSRUse::Basic: // Only handle single-register values. - return !BaseGV && Scale == 0 && BaseOffset == 0; + return !BaseGV && Scale == 0 && BaseOffset == 0 && ScalableOffset == 0; case LSRUse::Special: // Special case Basic to handle -1 scales. - return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0; + return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0 && + ScalableOffset == 0; } llvm_unreachable("Invalid LSRUse Kind!"); @@ -1843,7 +1847,7 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI, static bool isAlwaysFoldable(const TargetTransformInfo &TTI, LSRUse::KindType Kind, MemAccessTy AccessTy, GlobalValue *BaseGV, int64_t BaseOffset, - bool HasBaseReg) { + bool HasBaseReg, int64_t ScalableOffset = 0) { // Fast-path: zero is always foldable. if (BaseOffset == 0 && !BaseGV) return true; @@ -1859,7 +1863,7 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI, } return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, BaseOffset, - HasBaseReg, Scale); + HasBaseReg, Scale, nullptr, ScalableOffset); } static bool isAlwaysFoldable(const TargetTransformInfo &TTI, @@ -3165,16 +3169,30 @@ void LSRInstance::FinalizeChain(IVChain &Chain) { static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst, Value *Operand, const TargetTransformInfo &TTI) { const SCEVConstant *IncConst = dyn_cast(IncExpr); - if (!IncConst || !isAddressUse(TTI, UserInst, Operand)) - return false; + int64_t IncOffset = 0; + int64_t ScalableOffset = 0; + if (IncConst) { + if (IncConst && IncConst->getAPInt().getSignificantBits() > 64) + return false; + IncOffset = IncConst->getValue()->getSExtValue(); + } else { + // Look for mul(vscale, constant), to detect ScalableOffset. + auto *IncVScale = dyn_cast(IncExpr); + if (!IncVScale || IncVScale->getNumOperands() != 2 || + !isa(IncVScale->getOperand(1))) + return false; + auto *Scale = dyn_cast(IncVScale->getOperand(0)); + if (!Scale || Scale->getType()->getScalarSizeInBits() > 64) + return false; + ScalableOffset = Scale->getValue()->getSExtValue(); + } - if (IncConst->getAPInt().getSignificantBits() > 64) + if (!isAddressUse(TTI, UserInst, Operand)) return false; MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand); - int64_t IncOffset = IncConst->getValue()->getSExtValue(); if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr, - IncOffset, /*HasBaseReg=*/false)) + IncOffset, /*HasBaseReg=*/false, ScalableOffset)) return false; return true; @@ -3220,6 +3238,10 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, Type *IVTy = IVSrc->getType(); Type *IntTy = SE.getEffectiveSCEVType(IVTy); const SCEV *LeftOverExpr = nullptr; + const SCEV *Accum = SE.getZero(IntTy); + SmallVector> Bases; + Bases.emplace_back(Accum, IVSrc); + for (const IVInc &Inc : Chain) { Instruction *InsertPt = Inc.UserInst; if (isa(InsertPt)) @@ -3232,10 +3254,31 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, // IncExpr was the result of subtraction of two narrow values, so must // be signed. const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy); + Accum = SE.getAddExpr(Accum, IncExpr); LeftOverExpr = LeftOverExpr ? SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr; } - if (LeftOverExpr && !LeftOverExpr->isZero()) { + + // Look through each base to see if any can produce a nice addressing mode. + bool FoundBase = false; + for (auto [MapScev, MapIVOper] : reverse(Bases)) { + const SCEV *Remainder = SE.getMinusSCEV(Accum, MapScev); + if (canFoldIVIncExpr(Remainder, Inc.UserInst, Inc.IVOperand, TTI)) { + if (!Remainder->isZero()) { + Rewriter.clearPostInc(); + Value *IncV = Rewriter.expandCodeFor(Remainder, IntTy, InsertPt); + const SCEV *IVOperExpr = + SE.getAddExpr(SE.getUnknown(MapIVOper), SE.getUnknown(IncV)); + IVOper = Rewriter.expandCodeFor(IVOperExpr, IVTy, InsertPt); + } else { + IVOper = MapIVOper; + } + + FoundBase = true; + break; + } + } + if (!FoundBase && LeftOverExpr && !LeftOverExpr->isZero()) { // Expand the IV increment. Rewriter.clearPostInc(); Value *IncV = Rewriter.expandCodeFor(LeftOverExpr, IntTy, InsertPt); @@ -3246,6 +3289,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain, // If an IV increment can't be folded, use it as the next IV value. if (!canFoldIVIncExpr(LeftOverExpr, Inc.UserInst, Inc.IVOperand, TTI)) { assert(IVTy == IVOper->getType() && "inconsistent IV increment type"); + Bases.emplace_back(Accum, IVOper); IVSrc = IVOper; LeftOverExpr = nullptr; } diff --git a/llvm/test/CodeGen/AArch64/sve-lsrchain.ll b/llvm/test/CodeGen/AArch64/sve-lsrchain.ll index 9c7bffb921ce29..1931cfc2ef51de 100644 --- a/llvm/test/CodeGen/AArch64/sve-lsrchain.ll +++ b/llvm/test/CodeGen/AArch64/sve-lsrchain.ll @@ -14,24 +14,22 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float ; CHECK-NEXT: // %bb.2: // %for.body.us.preheader ; CHECK-NEXT: ptrue p0.h ; CHECK-NEXT: add x11, x2, x11, lsl #1 -; CHECK-NEXT: mov x12, #-16 // =0xfffffffffffffff0 -; CHECK-NEXT: ptrue p1.b ; CHECK-NEXT: mov w8, wzr +; CHECK-NEXT: ptrue p1.b ; CHECK-NEXT: mov x9, xzr ; CHECK-NEXT: mov w10, wzr -; CHECK-NEXT: addvl x12, x12, #1 -; CHECK-NEXT: mov x13, #4 // =0x4 -; CHECK-NEXT: mov x14, #8 // =0x8 +; CHECK-NEXT: mov x12, #4 // =0x4 +; CHECK-NEXT: mov x13, #8 // =0x8 ; CHECK-NEXT: .LBB0_3: // %for.body.us ; CHECK-NEXT: // =>This Loop Header: Depth=1 ; CHECK-NEXT: // Child Loop BB0_4 Depth 2 -; CHECK-NEXT: add x15, x0, x9, lsl #2 -; CHECK-NEXT: sbfiz x16, x8, #1, #32 -; CHECK-NEXT: mov x17, x2 -; CHECK-NEXT: ldp s0, s1, [x15] -; CHECK-NEXT: add x16, x16, #8 -; CHECK-NEXT: ldp s2, s3, [x15, #8] -; CHECK-NEXT: ubfiz x15, x8, #1, #32 +; CHECK-NEXT: add x14, x0, x9, lsl #2 +; CHECK-NEXT: sbfiz x15, x8, #1, #32 +; CHECK-NEXT: mov x16, x2 +; CHECK-NEXT: ldp s0, s1, [x14] +; CHECK-NEXT: add x15, x15, #8 +; CHECK-NEXT: ldp s2, s3, [x14, #8] +; CHECK-NEXT: ubfiz x14, x8, #1, #32 ; CHECK-NEXT: fcvt h0, s0 ; CHECK-NEXT: fcvt h1, s1 ; CHECK-NEXT: fcvt h2, s2 @@ -43,56 +41,52 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float ; CHECK-NEXT: .LBB0_4: // %for.cond.i.preheader.us ; CHECK-NEXT: // Parent Loop BB0_3 Depth=1 ; CHECK-NEXT: // => This Inner Loop Header: Depth=2 -; CHECK-NEXT: ld1b { z4.b }, p1/z, [x17, x15] -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17] -; CHECK-NEXT: add x18, x17, x16 -; CHECK-NEXT: add x3, x17, x15 +; CHECK-NEXT: ld1b { z4.b }, p1/z, [x16, x14] +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16] +; CHECK-NEXT: add x17, x16, x15 +; CHECK-NEXT: add x18, x16, x14 +; CHECK-NEXT: add x3, x17, #8 +; CHECK-NEXT: add x4, x17, #16 ; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h -; CHECK-NEXT: ld1b { z5.b }, p1/z, [x17, x16] +; CHECK-NEXT: ld1b { z5.b }, p1/z, [x16, x15] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1] +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, x12, lsl #1] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1] -; CHECK-NEXT: add x18, x18, #16 +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, x13, lsl #1] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #1, mul vl] -; CHECK-NEXT: st1h { z4.h }, p0, [x17] -; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #1, mul vl] +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #1, mul vl] +; CHECK-NEXT: st1h { z4.h }, p0, [x16] +; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #1, mul vl] ; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h -; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12] -; CHECK-NEXT: add x18, x18, x12 +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #1, mul vl] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1] +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #1, mul vl] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1] -; CHECK-NEXT: add x18, x18, #16 +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #1, mul vl] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #2, mul vl] -; CHECK-NEXT: st1h { z4.h }, p0, [x17, #1, mul vl] -; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #2, mul vl] +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #2, mul vl] +; CHECK-NEXT: st1h { z4.h }, p0, [x16, #1, mul vl] +; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #2, mul vl] ; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h -; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12] -; CHECK-NEXT: add x18, x18, x12 +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #2, mul vl] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1] +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #2, mul vl] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1] -; CHECK-NEXT: add x18, x18, #16 +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #2, mul vl] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #3, mul vl] -; CHECK-NEXT: st1h { z4.h }, p0, [x17, #2, mul vl] -; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #3, mul vl] +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #3, mul vl] +; CHECK-NEXT: st1h { z4.h }, p0, [x16, #2, mul vl] +; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #3, mul vl] ; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h -; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12] -; CHECK-NEXT: add x18, x18, x12 +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #3, mul vl] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1] +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #3, mul vl] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h -; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1] +; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #3, mul vl] ; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h -; CHECK-NEXT: st1h { z4.h }, p0, [x17, #3, mul vl] -; CHECK-NEXT: addvl x17, x17, #4 -; CHECK-NEXT: cmp x17, x11 +; CHECK-NEXT: st1h { z4.h }, p0, [x16, #3, mul vl] +; CHECK-NEXT: addvl x16, x16, #4 +; CHECK-NEXT: cmp x16, x11 ; CHECK-NEXT: b.lo .LBB0_4 ; CHECK-NEXT: // %bb.5: // %while.cond.i..exit_crit_edge.us ; CHECK-NEXT: // in Loop: Header=BB0_3 Depth=1