Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoopVectorize] Refine runtime memory check costs when there is an outer loop #76034

Merged
merged 6 commits into from
Jan 26, 2024
62 changes: 56 additions & 6 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,8 @@ class GeneratedRTChecks {
bool CostTooHigh = false;
const bool AddBranchWeights;

Loop *OuterLoop = nullptr;

public:
GeneratedRTChecks(ScalarEvolution &SE, DominatorTree *DT, LoopInfo *LI,
TargetTransformInfo *TTI, const DataLayout &DL,
Expand Down Expand Up @@ -2053,6 +2055,9 @@ class GeneratedRTChecks {
DT->eraseNode(SCEVCheckBlock);
LI->removeBlock(SCEVCheckBlock);
}

// Outer loop is used as part of the later cost calculations.
OuterLoop = L->getParentLoop();
}

InstructionCost getCost() {
Expand All @@ -2076,16 +2081,61 @@ class GeneratedRTChecks {
LLVM_DEBUG(dbgs() << " " << C << " for " << I << "\n");
RTCheckCost += C;
}
if (MemCheckBlock)
if (MemCheckBlock) {
InstructionCost MemCheckCost = 0;
for (Instruction &I : *MemCheckBlock) {
if (MemCheckBlock->getTerminator() == &I)
continue;
InstructionCost C =
TTI->getInstructionCost(&I, TTI::TCK_RecipThroughput);
LLVM_DEBUG(dbgs() << " " << C << " for " << I << "\n");
RTCheckCost += C;
MemCheckCost += C;
}

// If the runtime memory checks are being created inside an outer loop
// we should find out if these checks are outer loop invariant. If so,
// the checks will likely be hoisted out and so the effective cost will
// reduce according to the outer loop trip count.
if (OuterLoop) {
ScalarEvolution *SE = MemCheckExp.getSE();
// TODO: If profitable, we could refine this further by analysing every
// individual memory check, since there could be a mixture of loop
// variant and invariant checks that mean the final condition is
// variant.
const SCEV *Cond = SE->getSCEV(MemRuntimeCheckCond);
david-arm marked this conversation as resolved.
Show resolved Hide resolved
if (SE->isLoopInvariant(Cond, OuterLoop)) {
// It seems reasonable to assume that we can reduce the effective
// cost of the checks even when we know nothing about the trip
// count. Assume that the outer loop executes at least twice.
unsigned BestTripCount = 2;

// If exact trip count is known use that.
if (unsigned SmallTC = SE->getSmallConstantTripCount(OuterLoop))
BestTripCount = SmallTC;
else if (LoopVectorizeWithBlockFrequency) {
// Else use profile data if available.
if (auto EstimatedTC = getLoopEstimatedTripCount(OuterLoop))
BestTripCount = *EstimatedTC;
}

InstructionCost NewMemCheckCost = MemCheckCost / BestTripCount;

// Let's ensure the cost is always at least 1.
NewMemCheckCost = std::max(*NewMemCheckCost.getValue(),
(InstructionCost::CostType)1);

LLVM_DEBUG(dbgs()
<< "We expect runtime memory checks to be hoisted "
<< "out of the outer loop. Cost reduced from "
<< MemCheckCost << " to " << NewMemCheckCost << '\n');

MemCheckCost = NewMemCheckCost;
}
}

RTCheckCost += MemCheckCost;
}

if (SCEVCheckBlock || MemCheckBlock)
LLVM_DEBUG(dbgs() << "Total cost of runtime checks: " << RTCheckCost
david-arm marked this conversation as resolved.
Show resolved Hide resolved
<< "\n");
Expand Down Expand Up @@ -2144,8 +2194,8 @@ class GeneratedRTChecks {

BranchInst::Create(LoopVectorPreHeader, SCEVCheckBlock);
// Create new preheader for vector loop.
if (auto *PL = LI->getLoopFor(LoopVectorPreHeader))
PL->addBasicBlockToLoop(SCEVCheckBlock, *LI);
if (OuterLoop)
OuterLoop->addBasicBlockToLoop(SCEVCheckBlock, *LI);

SCEVCheckBlock->getTerminator()->eraseFromParent();
SCEVCheckBlock->moveBefore(LoopVectorPreHeader);
Expand Down Expand Up @@ -2179,8 +2229,8 @@ class GeneratedRTChecks {
DT->changeImmediateDominator(LoopVectorPreHeader, MemCheckBlock);
MemCheckBlock->moveBefore(LoopVectorPreHeader);

if (auto *PL = LI->getLoopFor(LoopVectorPreHeader))
PL->addBasicBlockToLoop(MemCheckBlock, *LI);
if (OuterLoop)
OuterLoop->addBasicBlockToLoop(MemCheckBlock, *LI);

BranchInst &BI =
*BranchInst::Create(Bypass, LoopVectorPreHeader, MemRuntimeCheckCond);
Expand Down
217 changes: 217 additions & 0 deletions llvm/test/Transforms/LoopVectorize/AArch64/low_trip_memcheck_cost.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
; REQUIRES: asserts
; RUN: opt -p loop-vectorize -debug-only=loop-vectorize -S -disable-output < %s 2>&1 | FileCheck %s

david-arm marked this conversation as resolved.
Show resolved Hide resolved
target triple = "aarch64-unknown-linux-gnu"

define void @no_outer_loop(ptr nocapture noundef %a, ptr nocapture noundef readonly %b, i64 noundef %off, i64 noundef %n) {
; CHECK-LABEL: LV: Checking a loop in 'no_outer_loop'
; CHECK: Calculating cost of runtime checks:
; CHECK-NOT: We expect runtime memory checks to be hoisted out of the outer loop.
; CHECK: Total cost of runtime checks: 4
; CHECK-NEXT: LV: Minimum required TC for runtime checks to be profitable:16
entry:
br label %inner.loop

inner.loop:
%inner.iv = phi i64 [ 0, %entry ], [ %inner.iv.next, %inner.loop ]
%add.us = add nuw nsw i64 %inner.iv, %off
%arrayidx.us = getelementptr inbounds i8, ptr %b, i64 %add.us
%0 = load i8, ptr %arrayidx.us, align 1
%arrayidx7.us = getelementptr inbounds i8, ptr %a, i64 %add.us
%1 = load i8, ptr %arrayidx7.us, align 1
%add9.us = add i8 %1, %0
store i8 %add9.us, ptr %arrayidx7.us, align 1
%inner.iv.next = add nuw nsw i64 %inner.iv, 1
%exitcond.not = icmp eq i64 %inner.iv.next, %n
br i1 %exitcond.not, label %inner.exit, label %inner.loop

inner.exit:
ret void
}

define void @outer_no_tc(ptr nocapture noundef %a, ptr nocapture noundef readonly %b, i64 noundef %m, i64 noundef %n) {
david-arm marked this conversation as resolved.
Show resolved Hide resolved
; CHECK-LABEL: LV: Checking a loop in 'outer_no_tc'
; CHECK: Calculating cost of runtime checks:
; CHECK: We expect runtime memory checks to be hoisted out of the outer loop. Cost reduced from 6 to 3
; CHECK: Total cost of runtime checks: 3
; CHECK-NEXT: LV: Minimum required TC for runtime checks to be profitable:16
entry:
br label %outer.loop

outer.loop:
%outer.iv = phi i64 [ %outer.iv.next, %inner.exit ], [ 0, %entry ]
%mul.us = mul nsw i64 %outer.iv, %n
br label %inner.loop

inner.loop:
%inner.iv = phi i64 [ 0, %outer.loop ], [ %inner.iv.next, %inner.loop ]
%add.us = add nuw nsw i64 %inner.iv, %mul.us
%arrayidx.us = getelementptr inbounds i8, ptr %b, i64 %add.us
%0 = load i8, ptr %arrayidx.us, align 1
%arrayidx7.us = getelementptr inbounds i8, ptr %a, i64 %add.us
%1 = load i8, ptr %arrayidx7.us, align 1
%add9.us = add i8 %1, %0
store i8 %add9.us, ptr %arrayidx7.us, align 1
%inner.iv.next = add nuw nsw i64 %inner.iv, 1
%exitcond.not = icmp eq i64 %inner.iv.next, %n
br i1 %exitcond.not, label %inner.exit, label %inner.loop

inner.exit:
%outer.iv.next = add nuw nsw i64 %outer.iv, 1
%exitcond27.not = icmp eq i64 %outer.iv.next, %m
br i1 %exitcond27.not, label %outer.exit, label %outer.loop

outer.exit:
ret void
}


define void @outer_known_tc3(ptr nocapture noundef %a, ptr nocapture noundef readonly %b, i64 noundef %n) {
; CHECK-LABEL: LV: Checking a loop in 'outer_known_tc3'
; CHECK: Calculating cost of runtime checks:
; CHECK: We expect runtime memory checks to be hoisted out of the outer loop. Cost reduced from 6 to 2
; CHECK: Total cost of runtime checks: 2
; CHECK-NEXT: LV: Minimum required TC for runtime checks to be profitable:16
entry:
br label %outer.loop

outer.loop:
%outer.iv = phi i64 [ %outer.iv.next, %inner.exit ], [ 0, %entry ]
%mul.us = mul nsw i64 %outer.iv, %n
br label %inner.loop

inner.loop:
%inner.iv = phi i64 [ 0, %outer.loop ], [ %inner.iv.next, %inner.loop ]
%add.us = add nuw nsw i64 %inner.iv, %mul.us
%arrayidx.us = getelementptr inbounds i8, ptr %b, i64 %add.us
%0 = load i8, ptr %arrayidx.us, align 1
%arrayidx7.us = getelementptr inbounds i8, ptr %a, i64 %add.us
%1 = load i8, ptr %arrayidx7.us, align 1
%add9.us = add i8 %1, %0
store i8 %add9.us, ptr %arrayidx7.us, align 1
%inner.iv.next = add nuw nsw i64 %inner.iv, 1
%exitcond.not = icmp eq i64 %inner.iv.next, %n
br i1 %exitcond.not, label %inner.exit, label %inner.loop

inner.exit:
%outer.iv.next = add nuw nsw i64 %outer.iv, 1
%exitcond26.not = icmp eq i64 %outer.iv.next, 3
br i1 %exitcond26.not, label %outer.exit, label %outer.loop

outer.exit:
ret void
}


define void @outer_known_tc64(ptr nocapture noundef %a, ptr nocapture noundef readonly %b, i64 noundef %n) {
; CHECK-LABEL: LV: Checking a loop in 'outer_known_tc64'
; CHECK: Calculating cost of runtime checks:
; CHECK: We expect runtime memory checks to be hoisted out of the outer loop. Cost reduced from 6 to 1
; CHECK: Total cost of runtime checks: 1
; CHECK-NEXT: LV: Minimum required TC for runtime checks to be profitable:16
entry:
br label %outer.loop

outer.loop:
%outer.iv = phi i64 [ %outer.iv.next, %inner.exit ], [ 0, %entry ]
%mul.us = mul nsw i64 %outer.iv, %n
br label %inner.loop

inner.loop:
%inner.iv = phi i64 [ 0, %outer.loop ], [ %inner.iv.next, %inner.loop ]
%add.us = add nuw nsw i64 %inner.iv, %mul.us
%arrayidx.us = getelementptr inbounds i8, ptr %b, i64 %add.us
%0 = load i8, ptr %arrayidx.us, align 1
%arrayidx7.us = getelementptr inbounds i8, ptr %a, i64 %add.us
%1 = load i8, ptr %arrayidx7.us, align 1
%add9.us = add i8 %1, %0
store i8 %add9.us, ptr %arrayidx7.us, align 1
%inner.iv.next = add nuw nsw i64 %inner.iv, 1
%exitcond.not = icmp eq i64 %inner.iv.next, %n
br i1 %exitcond.not, label %inner.exit, label %inner.loop

inner.exit:
%outer.iv.next = add nuw nsw i64 %outer.iv, 1
%exitcond26.not = icmp eq i64 %outer.iv.next, 64
br i1 %exitcond26.not, label %outer.exit, label %outer.loop

outer.exit:
ret void
}


define void @outer_pgo_3(ptr nocapture noundef %a, ptr nocapture noundef readonly %b, i64 noundef %m, i64 noundef %n) {
; CHECK-LABEL: LV: Checking a loop in 'outer_pgo_3'
; CHECK: Calculating cost of runtime checks:
; CHECK: We expect runtime memory checks to be hoisted out of the outer loop. Cost reduced from 6 to 2
; CHECK: Total cost of runtime checks: 2
; CHECK-NEXT: LV: Minimum required TC for runtime checks to be profitable:16
entry:
br label %outer.loop

outer.loop:
%outer.iv = phi i64 [ %outer.iv.next, %inner.exit ], [ 0, %entry ]
%mul.us = mul nsw i64 %outer.iv, %n
br label %inner.loop

inner.loop:
%inner.iv = phi i64 [ 0, %outer.loop ], [ %inner.iv.next, %inner.loop ]
%add.us = add nuw nsw i64 %inner.iv, %mul.us
%arrayidx.us = getelementptr inbounds i8, ptr %b, i64 %add.us
%0 = load i8, ptr %arrayidx.us, align 1
%arrayidx7.us = getelementptr inbounds i8, ptr %a, i64 %add.us
%1 = load i8, ptr %arrayidx7.us, align 1
%add9.us = add i8 %1, %0
store i8 %add9.us, ptr %arrayidx7.us, align 1
%inner.iv.next = add nuw nsw i64 %inner.iv, 1
%exitcond.not = icmp eq i64 %inner.iv.next, %n
br i1 %exitcond.not, label %inner.exit, label %inner.loop

inner.exit:
%outer.iv.next = add nuw nsw i64 %outer.iv, 1
%exitcond26.not = icmp eq i64 %outer.iv.next, %m
br i1 %exitcond26.not, label %outer.exit, label %outer.loop, !prof !0

outer.exit:
ret void
}


define void @outer_known_tc3_full_range_checks(ptr nocapture noundef %dst, ptr nocapture noundef readonly %src, i64 noundef %n) {
; CHECK-LABEL: LV: Checking a loop in 'outer_known_tc3_full_range_checks'
; CHECK: Calculating cost of runtime checks:
; CHECK: We expect runtime memory checks to be hoisted out of the outer loop. Cost reduced from 6 to 2
; CHECK: Total cost of runtime checks: 2
; CHECK-NEXT: LV: Minimum required TC for runtime checks to be profitable:4
entry:
br label %outer.loop

outer.loop:
%outer.iv = phi i64 [ 0, %entry ], [ %outer.iv.next, %inner.exit ]
%0 = mul nsw i64 %outer.iv, %n
br label %inner.loop

inner.loop:
%iv.inner = phi i64 [ 0, %outer.loop ], [ %iv.inner.next, %inner.loop ]
%1 = add nuw nsw i64 %iv.inner, %0
%arrayidx.us = getelementptr inbounds i32, ptr %src, i64 %1
%2 = load i32, ptr %arrayidx.us, align 4
%arrayidx8.us = getelementptr inbounds i32, ptr %dst, i64 %1
%3 = load i32, ptr %arrayidx8.us, align 4
%add9.us = add nsw i32 %3, %2
store i32 %add9.us, ptr %arrayidx8.us, align 4
%iv.inner.next = add nuw nsw i64 %iv.inner, 1
%inner.exit.cond = icmp eq i64 %iv.inner.next, %n
br i1 %inner.exit.cond, label %inner.exit, label %inner.loop

inner.exit:
%outer.iv.next = add nuw nsw i64 %outer.iv, 1
%outer.exit.cond = icmp eq i64 %outer.iv.next, 3
br i1 %outer.exit.cond, label %outer.exit, label %outer.loop

outer.exit:
ret void
}


!0 = !{!"branch_weights", i32 10, i32 20}
Loading