Skip to content

Commit

Permalink
[LSR] Split the -lsr-term-fold transformation into it's own pass (llv…
Browse files Browse the repository at this point in the history
…m#104234)

This transformation doesn't actually use any of the internal state of
LSR and recomputes all information from SCEV.  Splitting it out makes
it easier to test.
    
Note that long term I would like to write a version of this transform
which *is* integrated with LSR's solver, but if that happens, we'll
just delete the extra pass.
    
Integration wise, I switched from using TTI to using a pass configuration
variable.  This seems slightly more idiomatic, and means we don't run
the extra logic on any target other than RISCV.
  • Loading branch information
preames authored Aug 18, 2024
1 parent 69115cc commit 27a62ec
Show file tree
Hide file tree
Showing 25 changed files with 438 additions and 291 deletions.
9 changes: 0 additions & 9 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -739,11 +739,6 @@ class TargetTransformInfo {
/// cost should return false, otherwise return true.
bool isNumRegsMajorCostOfLSR() const;

/// Return true if LSR should attempts to replace a use of an otherwise dead
/// primary IV in the latch condition with another IV available in the loop.
/// When successful, makes the primary IV dead.
bool shouldFoldTerminatingConditionAfterLSR() const;

/// Return true if LSR should drop a found solution if it's calculated to be
/// less profitable than the baseline.
bool shouldDropLSRSolutionIfLessProfitable() const;
Expand Down Expand Up @@ -1888,7 +1883,6 @@ class TargetTransformInfo::Concept {
virtual bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
const TargetTransformInfo::LSRCost &C2) = 0;
virtual bool isNumRegsMajorCostOfLSR() = 0;
virtual bool shouldFoldTerminatingConditionAfterLSR() const = 0;
virtual bool shouldDropLSRSolutionIfLessProfitable() const = 0;
virtual bool isProfitableLSRChainElement(Instruction *I) = 0;
virtual bool canMacroFuseCmp() = 0;
Expand Down Expand Up @@ -2367,9 +2361,6 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
bool isNumRegsMajorCostOfLSR() override {
return Impl.isNumRegsMajorCostOfLSR();
}
bool shouldFoldTerminatingConditionAfterLSR() const override {
return Impl.shouldFoldTerminatingConditionAfterLSR();
}
bool shouldDropLSRSolutionIfLessProfitable() const override {
return Impl.shouldDropLSRSolutionIfLessProfitable();
}
Expand Down
2 changes: 0 additions & 2 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,6 @@ class TargetTransformInfoImplBase {

bool isNumRegsMajorCostOfLSR() const { return true; }

bool shouldFoldTerminatingConditionAfterLSR() const { return false; }

bool shouldDropLSRSolutionIfLessProfitable() const { return false; }

bool isProfitableLSRChainElement(Instruction *I) const { return false; }
Expand Down
5 changes: 0 additions & 5 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,11 +394,6 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
return TargetTransformInfoImplBase::isNumRegsMajorCostOfLSR();
}

bool shouldFoldTerminatingConditionAfterLSR() const {
return TargetTransformInfoImplBase::
shouldFoldTerminatingConditionAfterLSR();
}

bool shouldDropLSRSolutionIfLessProfitable() const {
return TargetTransformInfoImplBase::shouldDropLSRSolutionIfLessProfitable();
}
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/CodeGen/TargetPassConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ class TargetPassConfig : public ImmutablePass {
/// callers.
bool RequireCodeGenSCCOrder = false;

/// Enable LoopTermFold immediately after LSR
bool EnableLoopTermFold = false;

/// Add the actual instruction selection passes. This does not include
/// preparation passes on IR.
bool addCoreISelPasses();
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/InitializePasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ void initializeLoopInfoWrapperPassPass(PassRegistry &);
void initializeLoopPassPass(PassRegistry &);
void initializeLoopSimplifyPass(PassRegistry &);
void initializeLoopStrengthReducePass(PassRegistry &);
void initializeLoopTermFoldPass(PassRegistry &);
void initializeLoopUnrollPass(PassRegistry &);
void initializeLowerAtomicLegacyPassPass(PassRegistry &);
void initializeLowerConstantIntrinsicsPass(PassRegistry &);
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/LinkAllPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ struct ForcePassLinking {
(void)llvm::createLoopExtractorPass();
(void)llvm::createLoopSimplifyPass();
(void)llvm::createLoopStrengthReducePass();
(void)llvm::createLoopTermFoldPass();
(void)llvm::createLoopUnrollPass();
(void)llvm::createLowerGlobalDtorsLegacyPass();
(void)llvm::createLowerInvokePass();
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/Passes/MachinePassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ FUNCTION_PASS("win-eh-prepare", WinEHPreparePass())
#define LOOP_PASS(NAME, CREATE_PASS)
#endif
LOOP_PASS("loop-reduce", LoopStrengthReducePass())
LOOP_PASS("loop-term-fold", LoopTermFoldPass())
#undef LOOP_PASS

#ifndef MACHINE_MODULE_PASS
Expand Down
8 changes: 8 additions & 0 deletions llvm/include/llvm/Transforms/Scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ Pass *createLICMPass();
//
Pass *createLoopStrengthReducePass();

//===----------------------------------------------------------------------===//
//
// LoopTermFold - This pass attempts to eliminate the last use of an IV in
// a loop terminator instruction by rewriting it in terms of another IV.
// Expected to be run immediately after LSR.
//
Pass *createLoopTermFoldPass();

//===----------------------------------------------------------------------===//
//
// LoopUnroll - This pass is a simple loop unrolling pass.
Expand Down
30 changes: 30 additions & 0 deletions llvm/include/llvm/Transforms/Scalar/LoopTermFold.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===- LoopTermFold.h - Loop Term Fold Pass ---------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_TRANSFORMS_SCALAR_LOOPTERMFOLD_H
#define LLVM_TRANSFORMS_SCALAR_LOOPTERMFOLD_H

#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/IR/PassManager.h"

namespace llvm {

class Loop;
class LPMUpdater;

class LoopTermFoldPass : public PassInfoMixin<LoopTermFoldPass> {
public:
PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR, LPMUpdater &U);
};

} // end namespace llvm

#endif // LLVM_TRANSFORMS_SCALAR_LOOPTERMFOLD_H
4 changes: 0 additions & 4 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,6 @@ bool TargetTransformInfo::isNumRegsMajorCostOfLSR() const {
return TTIImpl->isNumRegsMajorCostOfLSR();
}

bool TargetTransformInfo::shouldFoldTerminatingConditionAfterLSR() const {
return TTIImpl->shouldFoldTerminatingConditionAfterLSR();
}

bool TargetTransformInfo::shouldDropLSRSolutionIfLessProfitable() const {
return TTIImpl->shouldDropLSRSolutionIfLessProfitable();
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/TargetPassConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,8 @@ void TargetPassConfig::addIRPasses() {
if (!DisableLSR) {
addPass(createCanonicalizeFreezeInLoopsPass());
addPass(createLoopStrengthReducePass());
if (EnableLoopTermFold)
addPass(createLoopTermFoldPass());
if (PrintLSR)
addPass(createPrintFunctionPass(dbgs(),
"\n\n*** Code after LSR ***\n"));
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@
#include "llvm/Transforms/Scalar/LoopSimplifyCFG.h"
#include "llvm/Transforms/Scalar/LoopSink.h"
#include "llvm/Transforms/Scalar/LoopStrengthReduce.h"
#include "llvm/Transforms/Scalar/LoopTermFold.h"
#include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h"
#include "llvm/Transforms/Scalar/LoopUnrollPass.h"
#include "llvm/Transforms/Scalar/LoopVersioningLICM.h"
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Passes/PassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ LOOP_PASS("loop-idiom-vectorize", LoopIdiomVectorizePass())
LOOP_PASS("loop-instsimplify", LoopInstSimplifyPass())
LOOP_PASS("loop-predication", LoopPredicationPass())
LOOP_PASS("loop-reduce", LoopStrengthReducePass())
LOOP_PASS("loop-term-fold", LoopTermFoldPass())
LOOP_PASS("loop-simplifycfg", LoopSimplifyCFGPass())
LOOP_PASS("loop-unroll-full", LoopFullUnrollPass())
LOOP_PASS("loop-versioning-licm", LoopVersioningLICMPass())
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ class RISCVPassConfig : public TargetPassConfig {
if (TM.getOptLevel() != CodeGenOptLevel::None)
substitutePass(&PostRASchedulerID, &PostMachineSchedulerID);
setEnableSinkAndFold(EnableSinkFold);
EnableLoopTermFold = true;
}

RISCVTargetMachine &getRISCVTargetMachine() const {
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,6 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
const TargetTransformInfo::LSRCost &C2);

bool shouldFoldTerminatingConditionAfterLSR() const {
return true;
}
bool
shouldConsiderAddressTypePromotion(const Instruction &I,
bool &AllowPromotionWithoutCommonHeader);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Transforms/Scalar/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ add_llvm_component_library(LLVMScalarOpts
LoopRotation.cpp
LoopSimplifyCFG.cpp
LoopStrengthReduce.cpp
LoopTermFold.cpp
LoopUnrollPass.cpp
LoopUnrollAndJamPass.cpp
LoopVersioningLICM.cpp
Expand Down
Loading

0 comments on commit 27a62ec

Please sign in to comment.