Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SimplifyCFG] Consider a cross signed max-min table in switchToLookupTable #67885

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 66 additions & 28 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6534,18 +6534,20 @@ ShouldBuildLookupTable(SwitchInst *SI, uint64_t TableSize,
}

static bool ShouldUseSwitchConditionAsTableIndex(
ConstantInt &MinCaseVal, const ConstantInt &MaxCaseVal,
const ConstantInt &BeginCaseVal, const ConstantInt &EndCaseVal,
bool HasDefaultResults, const SmallDenseMap<PHINode *, Type *> &ResultTypes,
const DataLayout &DL, const TargetTransformInfo &TTI) {
if (MinCaseVal.isNullValue())
if (BeginCaseVal.isNullValue())
return true;
if (MinCaseVal.isNegative() ||
MaxCaseVal.getLimitedValue() == std::numeric_limits<uint64_t>::max() ||
if (BeginCaseVal.getValue().sge(EndCaseVal.getValue()))
return false;
if (BeginCaseVal.isNegative() ||
EndCaseVal.getLimitedValue() == std::numeric_limits<uint64_t>::max() ||
!HasDefaultResults)
return false;
return all_of(ResultTypes, [&](const auto &KV) {
return SwitchLookupTable::WouldFitInRegister(
DL, MaxCaseVal.getLimitedValue() + 1 /* TableSize */,
DL, EndCaseVal.getLimitedValue() + 1 /* TableSize */,
KV.second /* ResultType */);
});
}
Expand Down Expand Up @@ -6637,7 +6639,8 @@ static void reuseTableCompare(
/// lookup tables.
static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
DomTreeUpdater *DTU, const DataLayout &DL,
const TargetTransformInfo &TTI) {
const TargetTransformInfo &TTI,
bool ConsiderCrossSignedMaxMinTable) {
assert(SI->getNumCases() > 1 && "Degenerate switch?");

BasicBlock *BB = SI->getParent();
Expand All @@ -6663,9 +6666,6 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
// Figure out the corresponding result for each case value and phi node in the
// common destination, as well as the min and max case values.
assert(!SI->cases().empty());
SwitchInst::CaseIt CI = SI->case_begin();
ConstantInt *MinCaseVal = CI->getCaseValue();
ConstantInt *MaxCaseVal = CI->getCaseValue();

BasicBlock *CommonDest = nullptr;

Expand All @@ -6676,17 +6676,51 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
SmallDenseMap<PHINode *, Type *> ResultTypes;
SmallVector<PHINode *, 4> PHIs;

for (SwitchInst::CaseIt E = SI->case_end(); CI != E; ++CI) {
ConstantInt *CaseVal = CI->getCaseValue();
if (CaseVal->getValue().slt(MinCaseVal->getValue()))
MinCaseVal = CaseVal;
if (CaseVal->getValue().sgt(MaxCaseVal->getValue()))
MaxCaseVal = CaseVal;
auto CaseVals = llvm::map_range(
SI->cases(), [](const auto &C) { return C.getCaseValue(); });
auto *SignedMin =
*llvm::min_element(CaseVals, [](const auto *L, const auto *R) {
return L->getValue().slt(R->getValue());
});
auto *SignedMax =
*llvm::max_element(CaseVals, [](const auto *L, const auto *R) {
return L->getValue().slt(R->getValue());
});
auto *UnsignedMin =
*llvm::min_element(CaseVals, [](const auto *L, const auto *R) {
return L->getValue().ult(R->getValue());
});
auto *UnsignedMax =
*llvm::max_element(CaseVals, [](const auto *L, const auto *R) {
return L->getValue().ult(R->getValue());
});
APInt UnsignedDif = UnsignedMax->getValue() - UnsignedMin->getValue();
APInt SignedDif = SignedMax->getValue() - SignedMin->getValue();

ConstantInt *BeginCaseVal = nullptr;
ConstantInt *EndCaseVal = nullptr;
uint64_t LookupTableSize = 0;
bool CrossSignedMaxMinTableSmaller = UnsignedDif.ult(SignedDif);
if (ConsiderCrossSignedMaxMinTable && CrossSignedMaxMinTableSmaller) {
// We consider cases where the starting to the endpoint will cross the
// signed max and min. For example, for the i8 range `[-128, -127, 126,
// 127]`, we choose from 126 to -127. The length of the lookup table is 4.
BeginCaseVal = UnsignedMin;
EndCaseVal = UnsignedMax;
LookupTableSize = UnsignedDif.getLimitedValue() + 1;
} else {
BeginCaseVal = SignedMin;
EndCaseVal = SignedMax;
LookupTableSize = SignedDif.getLimitedValue() + 1;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC this is looking for max distance between any two values. If so think this can be simplified as:

ConstantInt *BeginCaseVal = nullptr;
ConstantInt *EndCaseVal = nullptr;
uint64_t LookupTableSize = 0;
APInt UnsignedDif = UnsignedMax(CaseVals) - UnsignedMin(CaseVals);
APInt SignedDif = SignedMax(CaseVals) - SignedMin(CaseVals);
if (ConsiderCrossSignedMaxMinTable && UnsignedDif.ult(SignedDif)) {
  // Cross signed min case
  BeginCaseVal = UnsignedMin(CaseVals);
  EndCaseVal = UnsignedMax(CaseVals);
  LookupTableSize = UnsignedDif.getLimitedValue() + 1;
} else {
  // Crossing 0 case
  BeginCaseVal = SignedMin(CaseVals);
  EndCaseVal = SignedMax(CaseVals);
  LookupTableSize = SignedDif.getLimitedValue() + 1;
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to me.


for (const auto &CI : SI->cases()) {
ConstantInt *CaseVal = CI.getCaseValue();

// Resulting value at phi nodes for this case value.
using ResultsTy = SmallVector<std::pair<PHINode *, Constant *>, 4>;
ResultsTy Results;
if (!getCaseResults(SI, CaseVal, CI->getCaseSuccessor(), &CommonDest,
if (!getCaseResults(SI, CaseVal, CI.getCaseSuccessor(), &CommonDest,
Results, DL, TTI))
return false;

Expand Down Expand Up @@ -6721,13 +6755,12 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
}

bool UseSwitchConditionAsTableIndex = ShouldUseSwitchConditionAsTableIndex(
*MinCaseVal, *MaxCaseVal, HasDefaultResults, ResultTypes, DL, TTI);
*BeginCaseVal, *EndCaseVal, HasDefaultResults, ResultTypes, DL, TTI);
uint64_t TableSize;
if (UseSwitchConditionAsTableIndex)
TableSize = MaxCaseVal->getLimitedValue() + 1;
TableSize = EndCaseVal->getLimitedValue() + 1;
else
TableSize =
(MaxCaseVal->getValue() - MinCaseVal->getValue()).getLimitedValue() + 1;
TableSize = LookupTableSize;

// If the default destination is unreachable, or if the lookup table covers
// all values of the conditional variable, branch directly to the lookup table
Expand Down Expand Up @@ -6757,13 +6790,16 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
}

if (!ShouldBuildLookupTable(SI, TableSize, TTI, DL, ResultTypes))
return false;
// When a signed max-min cannot construct a lookup table, try to find a
// range with a smaller lookup table.
return CrossSignedMaxMinTableSmaller && !ConsiderCrossSignedMaxMinTable &&
SwitchToLookupTable(SI, Builder, DTU, DL, TTI, true);

std::vector<DominatorTree::UpdateType> Updates;

// Compute the maximum table size representable by the integer type we are
// switching upon.
unsigned CaseSize = MinCaseVal->getType()->getPrimitiveSizeInBits();
unsigned CaseSize = BeginCaseVal->getType()->getPrimitiveSizeInBits();
uint64_t MaxTableSize = CaseSize > 63 ? UINT64_MAX : 1ULL << CaseSize;
assert(MaxTableSize >= TableSize &&
"It is impossible for a switch to have more entries than the max "
Expand All @@ -6779,15 +6815,17 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
Value *TableIndex;
ConstantInt *TableIndexOffset;
if (UseSwitchConditionAsTableIndex) {
TableIndexOffset = ConstantInt::get(MaxCaseVal->getIntegerType(), 0);
TableIndexOffset = ConstantInt::get(EndCaseVal->getIntegerType(), 0);
TableIndex = SI->getCondition();
} else {
TableIndexOffset = MinCaseVal;
TableIndexOffset = BeginCaseVal;
// If the default is unreachable, all case values are s>= MinCaseVal. Then
// we can try to attach nsw.
bool MayWrap = true;
if (!DefaultIsReachable) {
APInt Res = MaxCaseVal->getValue().ssub_ov(MinCaseVal->getValue(), MayWrap);
if (!DefaultIsReachable &&
EndCaseVal->getValue().sge(BeginCaseVal->getValue())) {
APInt Res =
EndCaseVal->getValue().ssub_ov(BeginCaseVal->getValue(), MayWrap);
(void)Res;
}

Expand Down Expand Up @@ -6830,7 +6868,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder,
// PHI value for the default case in case we're using a bit mask.
} else {
Value *Cmp = Builder.CreateICmpULT(
TableIndex, ConstantInt::get(MinCaseVal->getType(), TableSize));
TableIndex, ConstantInt::get(BeginCaseVal->getType(), TableSize));
RangeCheckBranch =
Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest());
if (DTU)
Expand Down Expand Up @@ -7145,7 +7183,7 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
// CVP. Therefore, only apply this transformation during late stages of the
// optimisation pipeline.
if (Options.ConvertSwitchToLookupTable &&
SwitchToLookupTable(SI, Builder, DTU, DL, TTI))
SwitchToLookupTable(SI, Builder, DTU, DL, TTI, false))
return requestResimplify();

if (simplifySwitchOfPowersOfTwo(SI, Builder, DL, TTI))
Expand Down
103 changes: 103 additions & 0 deletions llvm/test/Transforms/SimplifyCFG/X86/switch_to_lookup_table.ll
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,109 @@ return:

}

; The cross signed max-min table range is [122, -128]([122, 128]).

define i32 @f_i8_128(i8 %c) {
; CHECK-LABEL: @f_i8_128(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SWITCH_TABLEIDX:%.*]] = sub i8 [[C:%.*]], 122
; CHECK-NEXT: [[TMP0:%.*]] = icmp ult i8 [[SWITCH_TABLEIDX]], 7
; CHECK-NEXT: br i1 [[TMP0]], label [[SWITCH_LOOKUP:%.*]], label [[RETURN:%.*]]
; CHECK: switch.lookup:
; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [7 x i32], ptr @switch.table.f_i8_128, i32 0, i8 [[SWITCH_TABLEIDX]]
; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4
; CHECK-NEXT: br label [[RETURN]]
; CHECK: return:
; CHECK-NEXT: [[RETVAL_0:%.*]] = phi i32 [ [[SWITCH_LOAD]], [[SWITCH_LOOKUP]] ], [ 15, [[ENTRY:%.*]] ]
; CHECK-NEXT: ret i32 [[RETVAL_0]]
;
entry:
switch i8 %c, label %sw.default [
i8 122, label %return
i8 123, label %sw.bb1
i8 124, label %sw.bb2
i8 125, label %sw.bb3
i8 126, label %sw.bb4
i8 127, label %sw.bb5
i8 -128, label %sw.bb6
]

sw.bb1: br label %return
sw.bb2: br label %return
sw.bb3: br label %return
sw.bb4: br label %return
sw.bb5: br label %return
sw.bb6: br label %return
sw.default: br label %return
return:
%retval.0 = phi i32 [ 15, %sw.default ], [ 1, %sw.bb6 ], [ 62, %sw.bb5 ], [ 27, %sw.bb4 ], [ -1, %sw.bb3 ], [ 0, %sw.bb2 ], [ 123, %sw.bb1 ], [ 55, %entry ]
ret i32 %retval.0
}

; The cross signed max-min table range is [3, 0].

define i32 @f_min_max(i3 %c) {
; CHECK-LABEL: @f_min_max(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SWITCH_TABLEIDX:%.*]] = sub i3 [[C:%.*]], -4
; CHECK-NEXT: [[SWITCH_TABLEIDX_ZEXT:%.*]] = zext i3 [[SWITCH_TABLEIDX]] to i4
; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [8 x i32], ptr @switch.table.f_min_max, i32 0, i4 [[SWITCH_TABLEIDX_ZEXT]]
; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4
; CHECK-NEXT: ret i32 [[SWITCH_LOAD]]
;
entry:
switch i3 %c, label %sw.default [
i3 -4, label %return
i3 -3, label %sw.bb1
i3 -2, label %sw.bb2
i3 -1, label %sw.bb3
i3 0, label %sw.bb4
i3 3, label %sw.bb6
]

sw.bb1: br label %return
sw.bb2: br label %return
sw.bb3: br label %return
sw.bb4: br label %return
sw.bb6: br label %return
sw.default: br label %return
return:
%retval.0 = phi i32 [ 15, %sw.default ], [ 1, %sw.bb6 ], [ 27, %sw.bb4 ], [ -1, %sw.bb3 ], [ 0, %sw.bb2 ], [ 123, %sw.bb1 ], [ 55, %entry ]
ret i32 %retval.0
}

; The cross signed max-min table range is [-1, -4].

define i32 @f_min_max_2(i3 %c) {
; CHECK-LABEL: @f_min_max_2(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SWITCH_TABLEIDX:%.*]] = sub i3 [[C:%.*]], -4
; CHECK-NEXT: [[SWITCH_TABLEIDX_ZEXT:%.*]] = zext i3 [[SWITCH_TABLEIDX]] to i4
; CHECK-NEXT: [[SWITCH_GEP:%.*]] = getelementptr inbounds [8 x i32], ptr @switch.table.f_min_max_2, i32 0, i4 [[SWITCH_TABLEIDX_ZEXT]]
; CHECK-NEXT: [[SWITCH_LOAD:%.*]] = load i32, ptr [[SWITCH_GEP]], align 4
; CHECK-NEXT: ret i32 [[SWITCH_LOAD]]
;
entry:
switch i3 %c, label %sw.default [
i3 -1, label %return
i3 0, label %sw.bb1
i3 1, label %sw.bb2
i3 2, label %sw.bb3
i3 3, label %sw.bb4
i3 -4, label %sw.bb6
]

sw.bb1: br label %return
sw.bb2: br label %return
sw.bb3: br label %return
sw.bb4: br label %return
sw.bb6: br label %return
sw.default: br label %return
return:
%retval.0 = phi i32 [ 15, %sw.default ], [ 1, %sw.bb6 ], [ 27, %sw.bb4 ], [ -1, %sw.bb3 ], [ 0, %sw.bb2 ], [ 123, %sw.bb1 ], [ 55, %entry ]
ret i32 %retval.0
}

; A switch used to initialize two variables, an i8 and a float.

declare void @dummy(i8 signext, float)
Expand Down
Loading