Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Decompose more icmps into masks #110836

Merged
merged 4 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions llvm/include/llvm/Analysis/CmpInstAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,21 @@ namespace llvm {
Constant *getPredForFCmpCode(unsigned Code, Type *OpTy,
CmpInst::Predicate &Pred);

/// Represents the operation icmp (X & Mask) pred 0, where pred can only be
/// Represents the operation icmp (X & Mask) pred C, where pred can only be
/// eq or ne.
struct DecomposedBitTest {
/// The value whose bits are tested, i.e. the X in (X & Mask) pred C.
Value *X;
/// Predicate of the masked comparison; only eq or ne.
CmpInst::Predicate Pred;
/// Mask that is ANDed with X before the comparison.
APInt Mask;
/// Constant the masked value is compared against. decomposeBitTestICmp only
/// produces a non-zero value here when called with AllowNonZeroC set.
APInt C;
};

/// Decompose an icmp into the form ((X & Mask) pred 0) if possible.
/// Decompose an icmp into the form ((X & Mask) pred C) if possible.
/// Unless \p AllowNonZeroC is true, C will always be 0.
std::optional<DecomposedBitTest>
decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
bool LookThroughTrunc = true);
bool LookThroughTrunc = true,
bool AllowNonZeroC = false);

} // end namespace llvm

Expand Down
60 changes: 48 additions & 12 deletions llvm/lib/Analysis/CmpInstAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,

std::optional<DecomposedBitTest>
llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
bool LookThruTrunc) {
bool LookThruTrunc, bool AllowNonZeroC) {
using namespace PatternMatch;

const APInt *OrigC;
Expand All @@ -100,29 +100,65 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
switch (Pred) {
default:
llvm_unreachable("Unexpected predicate");
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLT: {
// X < 0 is equivalent to (X & SignMask) != 0.
if (!C.isZero())
return std::nullopt;
Result.Mask = APInt::getSignMask(C.getBitWidth());
Result.Pred = ICmpInst::ICMP_NE;
break;
if (C.isZero()) {
Result.Mask = APInt::getSignMask(C.getBitWidth());
Result.C = APInt::getZero(C.getBitWidth());
Result.Pred = ICmpInst::ICMP_NE;
break;
}

APInt FlippedSign = C ^ APInt::getSignMask(C.getBitWidth());
if (FlippedSign.isPowerOf2()) {
// X s< 10000100 is equivalent to ((X & 11111100) == 10000000)
Result.Mask = -FlippedSign;
Result.C = APInt::getSignMask(C.getBitWidth());
Result.Pred = ICmpInst::ICMP_EQ;
break;
}

if (FlippedSign.isNegatedPowerOf2()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test for this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is covered by and_sgt_to_mask. (It uses the inverse predicate so we get EQ instead of NE as a result.)

// X s< 01111100 is equivalent to ((X & 11111100) != 01111100)
Result.Mask = FlippedSign;
Result.C = C;
Result.Pred = ICmpInst::ICMP_NE;
break;
}

return std::nullopt;
}
case ICmpInst::ICMP_ULT:
// X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
if (!C.isPowerOf2())
return std::nullopt;
Result.Mask = -C;
Result.Pred = ICmpInst::ICMP_EQ;
break;
if (C.isPowerOf2()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test for this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is pre-existing, just moved into the if.

Result.Mask = -C;
Result.C = APInt::getZero(C.getBitWidth());
Result.Pred = ICmpInst::ICMP_EQ;
break;
}

// X u< 11111100 is equivalent to ((X & 11111100) != 11111100)
if (C.isNegatedPowerOf2()) {
Result.Mask = C;
Result.C = C;
Result.Pred = ICmpInst::ICMP_NE;
break;
}

return std::nullopt;
}

if (!AllowNonZeroC && !Result.C.isZero())
return std::nullopt;

if (Inverted)
Result.Pred = ICmpInst::getInversePredicate(Result.Pred);

Value *X;
if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
Result.X = X;
Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
Result.C = Result.C.zext(X->getType()->getScalarSizeInBits());
} else {
Result.X = LHS;
}
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,15 @@ static unsigned conjugateICmpMask(unsigned Mask) {
// Adapts the external decomposeBitTestICmp for local use.
static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
Value *&X, Value *&Y, Value *&Z) {
auto Res = llvm::decomposeBitTestICmp(LHS, RHS, Pred);
auto Res = llvm::decomposeBitTestICmp(
LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroC=*/true);
if (!Res)
return false;

Pred = Res->Pred;
X = Res->X;
Y = ConstantInt::get(X->getType(), Res->Mask);
Z = ConstantInt::get(X->getType(), 0);
Z = ConstantInt::get(X->getType(), Res->C);
return true;
}

Expand Down
26 changes: 5 additions & 21 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5932,31 +5932,15 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) {
return nullptr;

// This matches patterns corresponding to tests of the signbit as well as:
// (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
// (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true)) {
// (trunc X) pred C2 --> (X & Mask) ==/!= C
if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true,
/*AllowNonZeroC=*/true)) {
Value *And = Builder.CreateAnd(Res->X, Res->Mask);
Constant *Zero = ConstantInt::getNullValue(Res->X->getType());
return new ICmpInst(Res->Pred, And, Zero);
Constant *C = ConstantInt::get(Res->X->getType(), Res->C);
return new ICmpInst(Res->Pred, And, C);
}

unsigned SrcBits = X->getType()->getScalarSizeInBits();
if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) {
// If C is a negative power-of-2 (high-bit mask):
// (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?)
Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits));
Value *And = Builder.CreateAnd(X, MaskC);
return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC);
}

if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) {
// If C is not-of-power-of-2 (one clear bit):
// (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?)
Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits));
Value *And = Builder.CreateAnd(X, MaskC);
return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC);
}

if (auto *II = dyn_cast<IntrinsicInst>(X)) {
if (II->getIntrinsicID() == Intrinsic::cttz ||
II->getIntrinsicID() == Intrinsic::ctlz) {
Expand Down
17 changes: 5 additions & 12 deletions llvm/test/Transforms/InstCombine/and-or-icmps.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3335,10 +3335,7 @@ define i1 @icmp_eq_or_z_or_pow2orz_fail_bad_pred2(i8 %x, i8 %y) {

define i1 @and_slt_to_mask(i8 %x) {
; CHECK-LABEL: @and_slt_to_mask(
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[X:%.*]], -124
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
; CHECK-NEXT: [[AND2:%.*]] = icmp slt i8 [[X:%.*]], -126
; CHECK-NEXT: ret i1 [[AND2]]
;
%cmp = icmp slt i8 %x, -124
Expand All @@ -3365,10 +3362,8 @@ define i1 @and_slt_to_mask_off_by_one(i8 %x) {

define i1 @and_sgt_to_mask(i8 %x) {
; CHECK-LABEL: @and_sgt_to_mask(
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 123
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2
; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], 124
; CHECK-NEXT: ret i1 [[AND2]]
;
%cmp = icmp sgt i8 %x, 123
Expand All @@ -3395,10 +3390,8 @@ define i1 @and_sgt_to_mask_off_by_one(i8 %x) {

define i1 @and_ugt_to_mask(i8 %x) {
; CHECK-LABEL: @and_ugt_to_mask(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[X:%.*]], -5
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2
; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], -4
; CHECK-NEXT: ret i1 [[AND2]]
;
%cmp = icmp ugt i8 %x, -5
Expand Down
Loading