Skip to content

Commit

Permalink
[ValueTracking] Compute known FPClass from dominating condition (#80941)
Browse files Browse the repository at this point in the history
This patch improves `computeKnownFPClass` by using context-sensitive
information from `DomConditionCache`.
  • Loading branch information
dtcxzyw authored Feb 13, 2024
1 parent 95a204c commit 542a3cb
Show file tree
Hide file tree
Showing 3 changed files with 392 additions and 28 deletions.
7 changes: 6 additions & 1 deletion llvm/lib/Analysis/DomConditionCache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static void findAffectedValues(Value *Cond,
if (!Visited.insert(V).second)
continue;

ICmpInst::Predicate Pred;
CmpInst::Predicate Pred;
Value *A, *B;
// Only recurse into and/or if it matches the top-level and/or type.
if (TopLevelIsAnd ? match(V, m_LogicalAnd(m_Value(A), m_Value(B)))
Expand All @@ -67,6 +67,11 @@ static void findAffectedValues(Value *Cond,
if (match(A, m_Add(m_Value(X), m_ConstantInt())))
AddAffected(X);
}
} else if (match(Cond, m_CombineOr(m_FCmp(Pred, m_Value(A), m_Constant()),
m_Intrinsic<Intrinsic::is_fpclass>(
m_Value(A), m_Constant())))) {
// Handle patterns that computeKnownFPClass() support.
AddAffected(A);
}
}
}
Expand Down
91 changes: 64 additions & 27 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4225,9 +4225,53 @@ llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
return fcmpImpliesClass(Pred, F, LHS, *ConstRHS, LookThroughSrc);
}

static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
const SimplifyQuery &Q) {
FPClassTest KnownFromAssume = fcAllFlags;
static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
bool CondIsTrue,
const Instruction *CxtI,
KnownFPClass &KnownFromContext) {
CmpInst::Predicate Pred;
Value *LHS;
uint64_t ClassVal = 0;
const APFloat *CRHS;
// TODO: handle sign-bit check idiom
if (match(Cond, m_FCmp(Pred, m_Value(LHS), m_APFloat(CRHS)))) {
auto [CmpVal, MaskIfTrue, MaskIfFalse] = fcmpImpliesClass(
Pred, *CxtI->getParent()->getParent(), LHS, *CRHS, LHS != V);
if (CmpVal == V)
KnownFromContext.knownNot(~(CondIsTrue ? MaskIfTrue : MaskIfFalse));
} else if (match(Cond, m_Intrinsic<Intrinsic::is_fpclass>(
m_Value(LHS), m_ConstantInt(ClassVal)))) {
FPClassTest Mask = static_cast<FPClassTest>(ClassVal);
KnownFromContext.knownNot(CondIsTrue ? ~Mask : Mask);
}
}

static KnownFPClass computeKnownFPClassFromContext(const Value *V,
const SimplifyQuery &Q) {
KnownFPClass KnownFromContext;

if (!Q.CxtI)
return KnownFromContext;

if (Q.DC && Q.DT) {
// Handle dominating conditions.
for (BranchInst *BI : Q.DC->conditionsFor(V)) {
Value *Cond = BI->getCondition();

BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(0));
if (Q.DT->dominates(Edge0, Q.CxtI->getParent()))
computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/true, Q.CxtI,
KnownFromContext);

BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(1));
if (Q.DT->dominates(Edge1, Q.CxtI->getParent()))
computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/false, Q.CxtI,
KnownFromContext);
}
}

if (!Q.AC)
return KnownFromContext;

// Try to restrict the floating-point classes based on information from
// assumptions.
Expand All @@ -4245,25 +4289,11 @@ static FPClassTest computeKnownFPClassFromAssumes(const Value *V,
if (!isValidAssumeForContext(I, Q.CxtI, Q.DT))
continue;

CmpInst::Predicate Pred;
Value *LHS, *RHS;
uint64_t ClassVal = 0;
if (match(I->getArgOperand(0), m_FCmp(Pred, m_Value(LHS), m_Value(RHS)))) {
const APFloat *CRHS;
if (match(RHS, m_APFloat(CRHS))) {
auto [CmpVal, MaskIfTrue, MaskIfFalse] =
fcmpImpliesClass(Pred, *F, LHS, *CRHS, LHS != V);
if (CmpVal == V)
KnownFromAssume &= MaskIfTrue;
}
} else if (match(I->getArgOperand(0),
m_Intrinsic<Intrinsic::is_fpclass>(
m_Value(LHS), m_ConstantInt(ClassVal)))) {
KnownFromAssume &= static_cast<FPClassTest>(ClassVal);
}
computeKnownFPClassFromCond(V, I->getArgOperand(0), /*CondIsTrue=*/true,
Q.CxtI, KnownFromContext);
}

return KnownFromAssume;
return KnownFromContext;
}

void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
Expand Down Expand Up @@ -4371,17 +4401,21 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
KnownNotFromFlags |= fcInf;
}

if (Q.AC) {
FPClassTest AssumedClasses = computeKnownFPClassFromAssumes(V, Q);
KnownNotFromFlags |= ~AssumedClasses;
}
KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;

// We no longer need to find out about these bits from inputs if we can
// assume this from flags/attributes.
InterestedClasses &= ~KnownNotFromFlags;

auto ClearClassesFromFlags = make_scope_exit([=, &Known] {
Known.knownNot(KnownNotFromFlags);
if (!Known.SignBit && AssumedClasses.SignBit) {
if (*AssumedClasses.SignBit)
Known.signBitMustBeOne();
else
Known.signBitMustBeZero();
}
});

if (!Op)
Expand Down Expand Up @@ -5283,7 +5317,8 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,

bool First = true;

for (Value *IncValue : P->incoming_values()) {
for (const Use &U : P->operands()) {
Value *IncValue = U.get();
// Skip direct self references.
if (IncValue == P)
continue;
Expand All @@ -5292,8 +5327,10 @@ void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
// Recurse, but cap the recursion to two levels, because we don't want
// to waste time spinning around in loops. We need at least depth 2 to
// detect known sign bits.
computeKnownFPClass(IncValue, DemandedElts, InterestedClasses, KnownSrc,
PhiRecursionLimit, Q);
computeKnownFPClass(
IncValue, DemandedElts, InterestedClasses, KnownSrc,
PhiRecursionLimit,
Q.getWithInstruction(P->getIncomingBlock(U)->getTerminator()));

if (First) {
Known = KnownSrc;
Expand Down
Loading

0 comments on commit 542a3cb

Please sign in to comment.