diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index dde90abc06cd9..b2124c6106198 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -739,11 +739,6 @@ class TargetTransformInfo { /// cost should return false, otherwise return true. bool isNumRegsMajorCostOfLSR() const; - /// Return true if LSR should attempts to replace a use of an otherwise dead - /// primary IV in the latch condition with another IV available in the loop. - /// When successful, makes the primary IV dead. - bool shouldFoldTerminatingConditionAfterLSR() const; - /// Return true if LSR should drop a found solution if it's calculated to be /// less profitable than the baseline. bool shouldDropLSRSolutionIfLessProfitable() const; @@ -1888,7 +1883,6 @@ class TargetTransformInfo::Concept { virtual bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1, const TargetTransformInfo::LSRCost &C2) = 0; virtual bool isNumRegsMajorCostOfLSR() = 0; - virtual bool shouldFoldTerminatingConditionAfterLSR() const = 0; virtual bool shouldDropLSRSolutionIfLessProfitable() const = 0; virtual bool isProfitableLSRChainElement(Instruction *I) = 0; virtual bool canMacroFuseCmp() = 0; @@ -2367,9 +2361,6 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { bool isNumRegsMajorCostOfLSR() override { return Impl.isNumRegsMajorCostOfLSR(); } - bool shouldFoldTerminatingConditionAfterLSR() const override { - return Impl.shouldFoldTerminatingConditionAfterLSR(); - } bool shouldDropLSRSolutionIfLessProfitable() const override { return Impl.shouldDropLSRSolutionIfLessProfitable(); } diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index d208a710bb27f..11b07ac0b7fc4 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -244,8 +244,6 @@ class TargetTransformInfoImplBase { bool isNumRegsMajorCostOfLSR() const { return true; } - bool shouldFoldTerminatingConditionAfterLSR() const { return false; } - bool shouldDropLSRSolutionIfLessProfitable() const { return false; } bool isProfitableLSRChainElement(Instruction *I) const { return false; } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index 77ddc10e8a0e7..217e3f1324f9c 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -394,11 +394,6 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { return TargetTransformInfoImplBase::isNumRegsMajorCostOfLSR(); } - bool shouldFoldTerminatingConditionAfterLSR() const { - return TargetTransformInfoImplBase:: - shouldFoldTerminatingConditionAfterLSR(); - } - bool shouldDropLSRSolutionIfLessProfitable() const { return TargetTransformInfoImplBase::shouldDropLSRSolutionIfLessProfitable(); } diff --git a/llvm/include/llvm/CodeGen/TargetPassConfig.h b/llvm/include/llvm/CodeGen/TargetPassConfig.h index d00e0bed91a45..2f5951e3ec3bc 100644 --- a/llvm/include/llvm/CodeGen/TargetPassConfig.h +++ b/llvm/include/llvm/CodeGen/TargetPassConfig.h @@ -140,6 +140,9 @@ class TargetPassConfig : public ImmutablePass { /// callers. bool RequireCodeGenSCCOrder = false; + /// Enable LoopTermFold immediately after LSR + bool EnableLoopTermFold = false; + /// Add the actual instruction selection passes. This does not include /// preparation passes on IR. bool addCoreISelPasses(); diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h index a4ac314bb590e..cc5e93c58f564 100644 --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -169,6 +169,7 @@ void initializeLoopInfoWrapperPassPass(PassRegistry &); void initializeLoopPassPass(PassRegistry &); void initializeLoopSimplifyPass(PassRegistry &); void initializeLoopStrengthReducePass(PassRegistry &); +void initializeLoopTermFoldPass(PassRegistry &); void initializeLoopUnrollPass(PassRegistry &); void initializeLowerAtomicLegacyPassPass(PassRegistry &); void initializeLowerConstantIntrinsicsPass(PassRegistry &); diff --git a/llvm/include/llvm/LinkAllPasses.h b/llvm/include/llvm/LinkAllPasses.h index e6a70dfd1ea6f..1da02153d846f 100644 --- a/llvm/include/llvm/LinkAllPasses.h +++ b/llvm/include/llvm/LinkAllPasses.h @@ -90,6 +90,7 @@ struct ForcePassLinking { (void)llvm::createLoopExtractorPass(); (void)llvm::createLoopSimplifyPass(); (void)llvm::createLoopStrengthReducePass(); + (void)llvm::createLoopTermFoldPass(); (void)llvm::createLoopUnrollPass(); (void)llvm::createLowerGlobalDtorsLegacyPass(); (void)llvm::createLowerInvokePass(); diff --git a/llvm/include/llvm/Passes/MachinePassRegistry.def b/llvm/include/llvm/Passes/MachinePassRegistry.def index 8e669ee579123..05baf514fa721 100644 --- a/llvm/include/llvm/Passes/MachinePassRegistry.def +++ b/llvm/include/llvm/Passes/MachinePassRegistry.def @@ -79,6 +79,7 @@ FUNCTION_PASS("win-eh-prepare", WinEHPreparePass()) #define LOOP_PASS(NAME, CREATE_PASS) #endif LOOP_PASS("loop-reduce", LoopStrengthReducePass()) +LOOP_PASS("loop-term-fold", LoopTermFoldPass()) #undef LOOP_PASS #ifndef MACHINE_MODULE_PASS diff --git a/llvm/include/llvm/Transforms/Scalar.h b/llvm/include/llvm/Transforms/Scalar.h index 98d0adca35521..17f4327eb3e1a 100644 --- a/llvm/include/llvm/Transforms/Scalar.h +++ b/llvm/include/llvm/Transforms/Scalar.h @@ -51,6 +51,14 @@ Pass *createLICMPass(); // Pass *createLoopStrengthReducePass(); +//===----------------------------------------------------------------------===// +// +// LoopTermFold - This pass attempts to eliminate the last use of an IV in +// a loop terminator instruction by rewriting it in terms of another IV. +// Expected to be run immediately after LSR. +// +Pass *createLoopTermFoldPass(); + //===----------------------------------------------------------------------===// // // LoopUnroll - This pass is a simple loop unrolling pass. diff --git a/llvm/include/llvm/Transforms/Scalar/LoopTermFold.h b/llvm/include/llvm/Transforms/Scalar/LoopTermFold.h new file mode 100644 index 0000000000000..974024c586aa8 --- /dev/null +++ b/llvm/include/llvm/Transforms/Scalar/LoopTermFold.h @@ -0,0 +1,30 @@ +//===- LoopTermFold.h - Loop Term Fold Pass ---------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_SCALAR_LOOPTERMFOLD_H +#define LLVM_TRANSFORMS_SCALAR_LOOPTERMFOLD_H + +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/IR/PassManager.h" + +namespace llvm { + +class Loop; +class LPMUpdater; + +class LoopTermFoldPass : public PassInfoMixin { +public: + PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, LPMUpdater &U); +}; + +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_SCALAR_LOOPTERMFOLD_H diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index dcde78925bfa9..2c26493bd3f1c 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -427,10 +427,6 @@ bool TargetTransformInfo::isNumRegsMajorCostOfLSR() const { return TTIImpl->isNumRegsMajorCostOfLSR(); } -bool TargetTransformInfo::shouldFoldTerminatingConditionAfterLSR() const { - return TTIImpl->shouldFoldTerminatingConditionAfterLSR(); -} - bool TargetTransformInfo::shouldDropLSRSolutionIfLessProfitable() const { return TTIImpl->shouldDropLSRSolutionIfLessProfitable(); } diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp index 1b0012b65b80d..1d52ebe6717f0 100644 --- a/llvm/lib/CodeGen/TargetPassConfig.cpp +++ b/llvm/lib/CodeGen/TargetPassConfig.cpp @@ -828,6 +828,8 @@ void TargetPassConfig::addIRPasses() { if (!DisableLSR) { addPass(createCanonicalizeFreezeInLoopsPass()); addPass(createLoopStrengthReducePass()); + if (EnableLoopTermFold) + addPass(createLoopTermFoldPass()); if (PrintLSR) addPass(createPrintFunctionPass(dbgs(), "\n\n*** Code after LSR ***\n")); diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp index 3200767282b22..17eed97fd950c 100644 --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -249,6 +249,7 @@ #include "llvm/Transforms/Scalar/LoopSimplifyCFG.h" #include "llvm/Transforms/Scalar/LoopSink.h" #include "llvm/Transforms/Scalar/LoopStrengthReduce.h" +#include "llvm/Transforms/Scalar/LoopTermFold.h" #include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h" #include "llvm/Transforms/Scalar/LoopUnrollPass.h" #include "llvm/Transforms/Scalar/LoopVersioningLICM.h" diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def index a11fc3755494a..6b5e1cf83c469 100644 --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -646,6 +646,7 @@ LOOP_PASS("loop-idiom-vectorize", LoopIdiomVectorizePass()) LOOP_PASS("loop-instsimplify", LoopInstSimplifyPass()) LOOP_PASS("loop-predication", LoopPredicationPass()) LOOP_PASS("loop-reduce", LoopStrengthReducePass()) +LOOP_PASS("loop-term-fold", LoopTermFoldPass()) LOOP_PASS("loop-simplifycfg", LoopSimplifyCFGPass()) LOOP_PASS("loop-unroll-full", LoopFullUnrollPass()) LOOP_PASS("loop-versioning-licm", LoopVersioningLICMPass()) diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp index b6884321f0841..794df2212dfa5 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp @@ -336,6 +336,7 @@ class RISCVPassConfig : public TargetPassConfig { if (TM.getOptLevel() != CodeGenOptLevel::None) substitutePass(&PostRASchedulerID, &PostMachineSchedulerID); setEnableSinkAndFold(EnableSinkFold); + EnableLoopTermFold = true; } RISCVTargetMachine &getRISCVTargetMachine() const { diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index f5eca2839acd0..cc69e1d118b5a 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -394,9 +394,6 @@ class RISCVTTIImpl : public BasicTTIImplBase { bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1, const TargetTransformInfo::LSRCost &C2); - bool shouldFoldTerminatingConditionAfterLSR() const { - return true; - } bool shouldConsiderAddressTypePromotion(const Instruction &I, bool &AllowPromotionWithoutCommonHeader); diff --git a/llvm/lib/Transforms/Scalar/CMakeLists.txt b/llvm/lib/Transforms/Scalar/CMakeLists.txt index ba09ebf8b04c4..939a145723956 100644 --- a/llvm/lib/Transforms/Scalar/CMakeLists.txt +++ b/llvm/lib/Transforms/Scalar/CMakeLists.txt @@ -44,6 +44,7 @@ add_llvm_component_library(LLVMScalarOpts LoopRotation.cpp LoopSimplifyCFG.cpp LoopStrengthReduce.cpp + LoopTermFold.cpp LoopUnrollPass.cpp LoopUnrollAndJamPass.cpp LoopVersioningLICM.cpp diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 91461d1ed2759..a62b87fe2a53d 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -189,10 +189,6 @@ static cl::opt SetupCostDepthLimit( "lsr-setupcost-depth-limit", cl::Hidden, cl::init(7), cl::desc("The limit on recursion depth for LSRs setup cost")); -static cl::opt AllowTerminatingConditionFoldingAfterLSR( - "lsr-term-fold", cl::Hidden, - cl::desc("Attempt to replace primary IV with other IV.")); - static cl::opt AllowDropSolutionIfLessProfitable( "lsr-drop-solution", cl::Hidden, cl::desc("Attempt to drop solution if it is less profitable")); @@ -205,9 +201,6 @@ static cl::opt DropScaledForVScale( "lsr-drop-scaled-reg-for-vscale", cl::Hidden, cl::init(true), cl::desc("Avoid using scaled registers with vscale-relative addressing")); -STATISTIC(NumTermFold, - "Number of terminating condition fold recognized and performed"); - #ifndef NDEBUG // Stress test IV chain generation. static cl::opt StressIVChain( @@ -7062,186 +7055,6 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE, return nullptr; } -static std::optional> -canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, - const LoopInfo &LI, const TargetTransformInfo &TTI) { - if (!L->isInnermost()) { - LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n"); - return std::nullopt; - } - // Only inspect on simple loop structure - if (!L->isLoopSimplifyForm()) { - LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n"); - return std::nullopt; - } - - if (!SE.hasLoopInvariantBackedgeTakenCount(L)) { - LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n"); - return std::nullopt; - } - - BasicBlock *LoopLatch = L->getLoopLatch(); - BranchInst *BI = dyn_cast(LoopLatch->getTerminator()); - if (!BI || BI->isUnconditional()) - return std::nullopt; - auto *TermCond = dyn_cast(BI->getCondition()); - if (!TermCond) { - LLVM_DEBUG( - dbgs() << "Cannot fold on branching condition that is not an ICmpInst"); - return std::nullopt; - } - if (!TermCond->hasOneUse()) { - LLVM_DEBUG( - dbgs() - << "Cannot replace terminating condition with more than one use\n"); - return std::nullopt; - } - - BinaryOperator *LHS = dyn_cast(TermCond->getOperand(0)); - Value *RHS = TermCond->getOperand(1); - if (!LHS || !L->isLoopInvariant(RHS)) - // We could pattern match the inverse form of the icmp, but that is - // non-canonical, and this pass is running *very* late in the pipeline. - return std::nullopt; - - // Find the IV used by the current exit condition. - PHINode *ToFold; - Value *ToFoldStart, *ToFoldStep; - if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep)) - return std::nullopt; - - // Ensure the simple recurrence is a part of the current loop. - if (ToFold->getParent() != L->getHeader()) - return std::nullopt; - - // If that IV isn't dead after we rewrite the exit condition in terms of - // another IV, there's no point in doing the transform. - if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond)) - return std::nullopt; - - // Inserting instructions in the preheader has a runtime cost, scale - // the allowed cost with the loops trip count as best we can. - const unsigned ExpansionBudget = [&]() { - unsigned Budget = 2 * SCEVCheapExpansionBudget; - if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L)) - return std::min(Budget, SmallTC); - if (std::optional SmallTC = getLoopEstimatedTripCount(L)) - return std::min(Budget, *SmallTC); - // Unknown trip count, assume long running by default. - return Budget; - }(); - - const SCEV *BECount = SE.getBackedgeTakenCount(L); - const DataLayout &DL = L->getHeader()->getDataLayout(); - SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); - - PHINode *ToHelpFold = nullptr; - const SCEV *TermValueS = nullptr; - bool MustDropPoison = false; - auto InsertPt = L->getLoopPreheader()->getTerminator(); - for (PHINode &PN : L->getHeader()->phis()) { - if (ToFold == &PN) - continue; - - if (!SE.isSCEVable(PN.getType())) { - LLVM_DEBUG(dbgs() << "IV of phi '" << PN - << "' is not SCEV-able, not qualified for the " - "terminating condition folding.\n"); - continue; - } - const SCEVAddRecExpr *AddRec = dyn_cast(SE.getSCEV(&PN)); - // Only speculate on affine AddRec - if (!AddRec || !AddRec->isAffine()) { - LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN - << "' is not an affine add recursion, not qualified " - "for the terminating condition folding.\n"); - continue; - } - - // Check that we can compute the value of AddRec on the exiting iteration - // without soundness problems. evaluateAtIteration internally needs - // to multiply the stride of the iteration number - which may wrap around. - // The issue here is subtle because computing the result accounting for - // wrap is insufficient. In order to use the result in an exit test, we - // must also know that AddRec doesn't take the same value on any previous - // iteration. The simplest case to consider is a candidate IV which is - // narrower than the trip count (and thus original IV), but this can - // also happen due to non-unit strides on the candidate IVs. - if (!AddRec->hasNoSelfWrap() || - !SE.isKnownNonZero(AddRec->getStepRecurrence(SE))) - continue; - - const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE); - const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE); - if (!Expander.isSafeToExpand(TermValueSLocal)) { - LLVM_DEBUG( - dbgs() << "Is not safe to expand terminating value for phi node" << PN - << "\n"); - continue; - } - - if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, - &TTI, InsertPt)) { - LLVM_DEBUG( - dbgs() << "Is too expensive to expand terminating value for phi node" - << PN << "\n"); - continue; - } - - // The candidate IV may have been otherwise dead and poison from the - // very first iteration. If we can't disprove that, we can't use the IV. - if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) { - LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " - << PN << "\n"); - continue; - } - - // The candidate IV may become poison on the last iteration. If this - // value is not branched on, this is a well defined program. We're - // about to add a new use to this IV, and we have to ensure we don't - // insert UB which didn't previously exist. - bool MustDropPoisonLocal = false; - Instruction *PostIncV = - cast(PN.getIncomingValueForBlock(LoopLatch)); - if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(), - &DT)) { - LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" - << PN << "\n"); - - // If this is a complex recurrance with multiple instructions computing - // the backedge value, we might need to strip poison flags from all of - // them. - if (PostIncV->getOperand(0) != &PN) - continue; - - // In order to perform the transform, we need to drop the poison generating - // flags on this instruction (if any). - MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags(); - } - - // We pick the last legal alternate IV. We could expore choosing an optimal - // alternate IV if we had a decent heuristic to do so. - ToHelpFold = &PN; - TermValueS = TermValueSLocal; - MustDropPoison = MustDropPoisonLocal; - } - - LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs() - << "Cannot find other AddRec IV to help folding\n";); - - LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs() - << "\nFound loop that can fold terminating condition\n" - << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n" - << " TermCond: " << *TermCond << "\n" - << " BrandInst: " << *BI << "\n" - << " ToFold: " << *ToFold << "\n" - << " ToHelpFold: " << *ToHelpFold << "\n"); - - if (!ToFold || !ToHelpFold) - return std::nullopt; - return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison); -} - static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, const TargetTransformInfo &TTI, @@ -7302,81 +7115,6 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, } } - const bool EnableFormTerm = [&] { - switch (AllowTerminatingConditionFoldingAfterLSR) { - case cl::BOU_TRUE: - return true; - case cl::BOU_FALSE: - return false; - case cl::BOU_UNSET: - return TTI.shouldFoldTerminatingConditionAfterLSR(); - } - llvm_unreachable("Unhandled cl::boolOrDefault enum"); - }(); - - if (EnableFormTerm) { - if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI)) { - auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; - - Changed = true; - NumTermFold++; - - BasicBlock *LoopPreheader = L->getLoopPreheader(); - BasicBlock *LoopLatch = L->getLoopLatch(); - - (void)ToFold; - LLVM_DEBUG(dbgs() << "To fold phi-node:\n" - << *ToFold << "\n" - << "New term-cond phi-node:\n" - << *ToHelpFold << "\n"); - - Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader); - (void)StartValue; - Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch); - - // See comment in canFoldTermCondOfLoop on why this is sufficient. - if (MustDrop) - cast(LoopValue)->dropPoisonGeneratingFlags(); - - // SCEVExpander for both use in preheader and latch - const DataLayout &DL = L->getHeader()->getDataLayout(); - SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); - - assert(Expander.isSafeToExpand(TermValueS) && - "Terminating value was checked safe in canFoldTerminatingCondition"); - - // Create new terminating value at loop preheader - Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(), - LoopPreheader->getTerminator()); - - LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n" - << *StartValue << "\n" - << "Terminating value of new term-cond phi-node:\n" - << *TermValue << "\n"); - - // Create new terminating condition at loop latch - BranchInst *BI = cast(LoopLatch->getTerminator()); - ICmpInst *OldTermCond = cast(BI->getCondition()); - IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); - Value *NewTermCond = - LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue, - "lsr_fold_term_cond.replaced_term_cond"); - // Swap successors to exit loop body if IV equals to new TermValue - if (BI->getSuccessor(0) == L->getHeader()) - BI->swapSuccessors(); - - LLVM_DEBUG(dbgs() << "Old term-cond:\n" - << *OldTermCond << "\n" - << "New term-cond:\n" << *NewTermCond << "\n"); - - BI->setCondition(NewTermCond); - - Expander.clear(); - OldTermCond->eraseFromParent(); - DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); - } - } - if (SalvageableDVIRecords.empty()) return Changed; diff --git a/llvm/lib/Transforms/Scalar/LoopTermFold.cpp b/llvm/lib/Transforms/Scalar/LoopTermFold.cpp new file mode 100644 index 0000000000000..12ef367adc43e --- /dev/null +++ b/llvm/lib/Transforms/Scalar/LoopTermFold.cpp @@ -0,0 +1,379 @@ +//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopTermFold.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Config/llvm-config.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" +#include +#include +#include + +using namespace llvm; + +#define DEBUG_TYPE "loop-term-fold" + +STATISTIC(NumTermFold, + "Number of terminating condition fold recognized and performed"); + +static std::optional> +canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, + const LoopInfo &LI, const TargetTransformInfo &TTI) { + if (!L->isInnermost()) { + LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n"); + return std::nullopt; + } + // Only inspect on simple loop structure + if (!L->isLoopSimplifyForm()) { + LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n"); + return std::nullopt; + } + + if (!SE.hasLoopInvariantBackedgeTakenCount(L)) { + LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n"); + return std::nullopt; + } + + BasicBlock *LoopLatch = L->getLoopLatch(); + BranchInst *BI = dyn_cast(LoopLatch->getTerminator()); + if (!BI || BI->isUnconditional()) + return std::nullopt; + auto *TermCond = dyn_cast(BI->getCondition()); + if (!TermCond) { + LLVM_DEBUG( + dbgs() << "Cannot fold on branching condition that is not an ICmpInst"); + return std::nullopt; + } + if (!TermCond->hasOneUse()) { + LLVM_DEBUG( + dbgs() + << "Cannot replace terminating condition with more than one use\n"); + return std::nullopt; + } + + BinaryOperator *LHS = dyn_cast(TermCond->getOperand(0)); + Value *RHS = TermCond->getOperand(1); + if (!LHS || !L->isLoopInvariant(RHS)) + // We could pattern match the inverse form of the icmp, but that is + // non-canonical, and this pass is running *very* late in the pipeline. + return std::nullopt; + + // Find the IV used by the current exit condition. + PHINode *ToFold; + Value *ToFoldStart, *ToFoldStep; + if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep)) + return std::nullopt; + + // Ensure the simple recurrence is a part of the current loop. + if (ToFold->getParent() != L->getHeader()) + return std::nullopt; + + // If that IV isn't dead after we rewrite the exit condition in terms of + // another IV, there's no point in doing the transform. + if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond)) + return std::nullopt; + + // Inserting instructions in the preheader has a runtime cost, scale + // the allowed cost with the loops trip count as best we can. + const unsigned ExpansionBudget = [&]() { + unsigned Budget = 2 * SCEVCheapExpansionBudget; + if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L)) + return std::min(Budget, SmallTC); + if (std::optional SmallTC = getLoopEstimatedTripCount(L)) + return std::min(Budget, *SmallTC); + // Unknown trip count, assume long running by default. + return Budget; + }(); + + const SCEV *BECount = SE.getBackedgeTakenCount(L); + const DataLayout &DL = L->getHeader()->getDataLayout(); + SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); + + PHINode *ToHelpFold = nullptr; + const SCEV *TermValueS = nullptr; + bool MustDropPoison = false; + auto InsertPt = L->getLoopPreheader()->getTerminator(); + for (PHINode &PN : L->getHeader()->phis()) { + if (ToFold == &PN) + continue; + + if (!SE.isSCEVable(PN.getType())) { + LLVM_DEBUG(dbgs() << "IV of phi '" << PN + << "' is not SCEV-able, not qualified for the " + "terminating condition folding.\n"); + continue; + } + const SCEVAddRecExpr *AddRec = dyn_cast(SE.getSCEV(&PN)); + // Only speculate on affine AddRec + if (!AddRec || !AddRec->isAffine()) { + LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN + << "' is not an affine add recursion, not qualified " + "for the terminating condition folding.\n"); + continue; + } + + // Check that we can compute the value of AddRec on the exiting iteration + // without soundness problems. evaluateAtIteration internally needs + // to multiply the stride of the iteration number - which may wrap around. + // The issue here is subtle because computing the result accounting for + // wrap is insufficient. In order to use the result in an exit test, we + // must also know that AddRec doesn't take the same value on any previous + // iteration. The simplest case to consider is a candidate IV which is + // narrower than the trip count (and thus original IV), but this can + // also happen due to non-unit strides on the candidate IVs. + if (!AddRec->hasNoSelfWrap() || + !SE.isKnownNonZero(AddRec->getStepRecurrence(SE))) + continue; + + const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE); + const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE); + if (!Expander.isSafeToExpand(TermValueSLocal)) { + LLVM_DEBUG( + dbgs() << "Is not safe to expand terminating value for phi node" << PN + << "\n"); + continue; + } + + if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI, + InsertPt)) { + LLVM_DEBUG( + dbgs() << "Is too expensive to expand terminating value for phi node" + << PN << "\n"); + continue; + } + + // The candidate IV may have been otherwise dead and poison from the + // very first iteration. If we can't disprove that, we can't use the IV. + if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) { + LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n"); + continue; + } + + // The candidate IV may become poison on the last iteration. If this + // value is not branched on, this is a well defined program. We're + // about to add a new use to this IV, and we have to ensure we don't + // insert UB which didn't previously exist. + bool MustDropPoisonLocal = false; + Instruction *PostIncV = + cast(PN.getIncomingValueForBlock(LoopLatch)); + if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(), + &DT)) { + LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN + << "\n"); + + // If this is a complex recurrance with multiple instructions computing + // the backedge value, we might need to strip poison flags from all of + // them. + if (PostIncV->getOperand(0) != &PN) + continue; + + // In order to perform the transform, we need to drop the poison + // generating flags on this instruction (if any). + MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags(); + } + + // We pick the last legal alternate IV. We could expore choosing an optimal + // alternate IV if we had a decent heuristic to do so. + ToHelpFold = &PN; + TermValueS = TermValueSLocal; + MustDropPoison = MustDropPoisonLocal; + } + + LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs() + << "Cannot find other AddRec IV to help folding\n";); + + LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs() + << "\nFound loop that can fold terminating condition\n" + << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n" + << " TermCond: " << *TermCond << "\n" + << " BrandInst: " << *BI << "\n" + << " ToFold: " << *ToFold << "\n" + << " ToHelpFold: " << *ToHelpFold << "\n"); + + if (!ToFold || !ToHelpFold) + return std::nullopt; + return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison); +} + +static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT, + LoopInfo &LI, const TargetTransformInfo &TTI, + TargetLibraryInfo &TLI, MemorySSA *MSSA) { + std::unique_ptr MSSAU; + if (MSSA) + MSSAU = std::make_unique(MSSA); + + auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI); + if (!Opt) + return false; + + auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; + + NumTermFold++; + + BasicBlock *LoopPreheader = L->getLoopPreheader(); + BasicBlock *LoopLatch = L->getLoopLatch(); + + (void)ToFold; + LLVM_DEBUG(dbgs() << "To fold phi-node:\n" + << *ToFold << "\n" + << "New term-cond phi-node:\n" + << *ToHelpFold << "\n"); + + Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader); + (void)StartValue; + Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch); + + // See comment in canFoldTermCondOfLoop on why this is sufficient. + if (MustDrop) + cast(LoopValue)->dropPoisonGeneratingFlags(); + + // SCEVExpander for both use in preheader and latch + const DataLayout &DL = L->getHeader()->getDataLayout(); + SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); + + assert(Expander.isSafeToExpand(TermValueS) && + "Terminating value was checked safe in canFoldTerminatingCondition"); + + // Create new terminating value at loop preheader + Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(), + LoopPreheader->getTerminator()); + + LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n" + << *StartValue << "\n" + << "Terminating value of new term-cond phi-node:\n" + << *TermValue << "\n"); + + // Create new terminating condition at loop latch + BranchInst *BI = cast(LoopLatch->getTerminator()); + ICmpInst *OldTermCond = cast(BI->getCondition()); + IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); + Value *NewTermCond = + LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue, + "lsr_fold_term_cond.replaced_term_cond"); + // Swap successors to exit loop body if IV equals to new TermValue + if (BI->getSuccessor(0) == L->getHeader()) + BI->swapSuccessors(); + + LLVM_DEBUG(dbgs() << "Old term-cond:\n" + << *OldTermCond << "\n" + << "New term-cond:\n" + << *NewTermCond << "\n"); + + BI->setCondition(NewTermCond); + + Expander.clear(); + OldTermCond->eraseFromParent(); + DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); + return true; +} + +namespace { + +class LoopTermFold : public LoopPass { +public: + static char ID; // Pass ID, replacement for typeid + + LoopTermFold(); + +private: + bool runOnLoop(Loop *L, LPPassManager &LPM) override; + void getAnalysisUsage(AnalysisUsage &AU) const override; +}; + +} // end anonymous namespace + +LoopTermFold::LoopTermFold() : LoopPass(ID) { + initializeLoopTermFoldPass(*PassRegistry::getPassRegistry()); +} + +void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.addPreserved(); + AU.addPreservedID(LoopSimplifyID); + AU.addRequiredID(LoopSimplifyID); + AU.addRequired(); + AU.addPreserved(); + AU.addRequired(); + AU.addPreserved(); + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); +} + +bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { + if (skipLoop(L)) + return false; + + auto &SE = getAnalysis().getSE(); + auto &DT = getAnalysis().getDomTree(); + auto &LI = getAnalysis().getLoopInfo(); + const auto &TTI = getAnalysis().getTTI( + *L->getHeader()->getParent()); + auto &TLI = getAnalysis().getTLI( + *L->getHeader()->getParent()); + auto *MSSAAnalysis = getAnalysisIfAvailable(); + MemorySSA *MSSA = nullptr; + if (MSSAAnalysis) + MSSA = &MSSAAnalysis->getMSSA(); + return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA); +} + +PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + if (!RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA)) + return PreservedAnalyses::all(); + + auto PA = getLoopPassPreservedAnalyses(); + if (AR.MSSA) + PA.preserve(); + return PA; +} + +char LoopTermFold::ID = 0; + +INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding", + false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding", + false, false) + +Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); } diff --git a/llvm/lib/Transforms/Scalar/Scalar.cpp b/llvm/lib/Transforms/Scalar/Scalar.cpp index 86669e8c5aa49..7aeee1d31f7e7 100644 --- a/llvm/lib/Transforms/Scalar/Scalar.cpp +++ b/llvm/lib/Transforms/Scalar/Scalar.cpp @@ -30,6 +30,7 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLegacyLICMPassPass(Registry); initializeLoopDataPrefetchLegacyPassPass(Registry); initializeLoopStrengthReducePass(Registry); + initializeLoopTermFoldPass(Registry); initializeLoopUnrollPass(Registry); initializeLowerAtomicLegacyPassPass(Registry); initializeMergeICmpsLegacyPassPass(Registry); diff --git a/llvm/test/CodeGen/RISCV/O3-pipeline.ll b/llvm/test/CodeGen/RISCV/O3-pipeline.ll index df9cb5de5d768..44c270fdc3c25 100644 --- a/llvm/test/CodeGen/RISCV/O3-pipeline.ll +++ b/llvm/test/CodeGen/RISCV/O3-pipeline.ll @@ -45,6 +45,7 @@ ; CHECK-NEXT: Canonicalize Freeze Instructions in Loops ; CHECK-NEXT: Induction Variable Users ; CHECK-NEXT: Loop Strength Reduction +; CHECK-NEXT: Loop Terminator Folding ; CHECK-NEXT: Basic Alias Analysis (stateless AA impl) ; CHECK-NEXT: Function Alias Analysis Results ; CHECK-NEXT: Merge contiguous icmps into a memcmp diff --git a/llvm/test/Transforms/LoopStrengthReduce/RISCV/lsr-cost-compare.ll b/llvm/test/Transforms/LoopStrengthReduce/RISCV/lsr-cost-compare.ll index 9c11bd064ad47..cadee94ff4096 100644 --- a/llvm/test/Transforms/LoopStrengthReduce/RISCV/lsr-cost-compare.ll +++ b/llvm/test/Transforms/LoopStrengthReduce/RISCV/lsr-cost-compare.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py -; RUN: opt < %s -loop-reduce -S | FileCheck %s +; RUN: opt < %s -passes=loop-reduce,loop-term-fold -S | FileCheck %s target datalayout = "e-m:e-p:64:64-i64:64-i128:128-n64-S128" target triple = "riscv64" diff --git a/llvm/test/Transforms/LoopStrengthReduce/RISCV/term-fold-crash.ll b/llvm/test/Transforms/LoopStrengthReduce/RISCV/term-fold-crash.ll index 8ca7f0010bbbe..9fb240684d232 100644 --- a/llvm/test/Transforms/LoopStrengthReduce/RISCV/term-fold-crash.ll +++ b/llvm/test/Transforms/LoopStrengthReduce/RISCV/term-fold-crash.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4 -; RUN: opt -S -passes=loop-reduce -mtriple=riscv64-unknown-linux-gnu < %s | FileCheck %s +; RUN: opt -S -passes=loop-reduce,loop-term-fold -mtriple=riscv64-unknown-linux-gnu < %s | FileCheck %s define void @test(ptr %p, i8 %arg, i32 %start) { ; CHECK-LABEL: define void @test( diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll index 2d3d3a4b72a1a..89ddba3343ffa 100644 --- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll +++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll @@ -1,6 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2 ; REQUIRES: asserts -; RUN: opt < %s -passes="loop-reduce" -S -debug -lsr-term-fold 2>&1 | FileCheck %s +; RUN: opt < %s -passes=loop-reduce,loop-term-fold -S -debug 2>&1 | FileCheck %s target datalayout = "e-p:64:64:64-n64" diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll index 7299a014b7983..6f34dc843ae1e 100644 --- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll +++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py -; RUN: opt < %s -passes="loop-reduce" -S -lsr-term-fold | FileCheck %s +; RUN: opt < %s -passes="loop-reduce,loop-term-fold" -S | FileCheck %s target datalayout = "e-p:64:64:64-n64" diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll index 1454535b52bcc..67a71496e4cec 100644 --- a/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll +++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-unreachable-bb-phi-node.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py -; RUN: opt < %s -loop-reduce -S -lsr-term-fold | FileCheck %s +; RUN: opt < %s -passes=loop-reduce,loop-term-fold -S | FileCheck %s ; This test used to crash due to matchSimpleRecurrence matching the simple ; recurrence in pn-loop when evaluating unrelated-loop. Since unrelated-loop @@ -13,9 +13,10 @@ define void @phi_node_different_bb() { ; CHECK-NEXT: [[TMP3:%.*]] = icmp ugt i32 [[TMP2]], 1 ; CHECK-NEXT: br i1 [[TMP3]], label [[PN_LOOP]], label [[UNRELATED_LOOP_PREHEADER:%.*]] ; CHECK: unrelated-loop.preheader: +; CHECK-NEXT: [[DOTLCSSA:%.*]] = phi i32 [ [[TMP2]], [[PN_LOOP]] ] ; CHECK-NEXT: br label [[UNRELATED_LOOP:%.*]] ; CHECK: unrelated-loop: -; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i32 [[TMP2]], 0 +; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i32 [[DOTLCSSA]], 0 ; CHECK-NEXT: br i1 [[TMP4]], label [[END:%.*]], label [[UNRELATED_LOOP]] ; CHECK: end: ; CHECK-NEXT: ret void