Skip to content

Commit

Permalink
[InstCombine] Fold icmp of trunc nuw/nsw (#90436)
Browse files Browse the repository at this point in the history
Convert the existing foldICmpTruncWithTruncOrExt() fold to work with
trunc nowrap flags instead of computeKnownBits(). This also allows us to
generalize the fold to work with signed comparisons.

Interestingly, apart from the obvious combinations such as signed
predicates with trunc nsw, some non-obvious ones are also legal. For
example, for unsigned predicates we can also do the transform when both
truncs are nsw (rather than only when both are nuw).

Proofs: https://alive2.llvm.org/ce/z/ndewwK
  • Loading branch information
nikic authored May 3, 2024
1 parent fc83eda commit b0eeacb
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 76 deletions.
27 changes: 27 additions & 0 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -1839,6 +1839,19 @@ template <typename Op_t> struct NNegZExt_match {
}
};

/// Matcher for a `trunc` instruction that carries (at least) every no-wrap
/// flag set in \p WrapFlags and whose source operand satisfies the nested
/// sub-pattern \p Op. With the default `WrapFlags == 0` no flag is required.
template <typename Op_t, unsigned WrapFlags = 0> struct NoWrapTrunc_match {
  Op_t Op;

  NoWrapTrunc_match(const Op_t &OpMatch) : Op(OpMatch) {}

  /// Returns true iff \p V is a TruncInst with all requested no-wrap flags
  /// and its operand 0 matches the sub-pattern.
  template <typename OpTy> bool match(OpTy *V) {
    auto *Trunc = dyn_cast<TruncInst>(V);
    if (!Trunc)
      return false;

    // All requested flags must be present; additional flags are accepted.
    if ((Trunc->getNoWrapKind() & WrapFlags) != WrapFlags)
      return false;

    return Op.match(Trunc->getOperand(0));
  }
};

/// Matches BitCast.
template <typename OpTy>
inline CastOperator_match<OpTy, Instruction::BitCast>
Expand Down Expand Up @@ -1900,6 +1913,20 @@ inline CastOperator_match<OpTy, Instruction::Trunc> m_Trunc(const OpTy &Op) {
return CastOperator_match<OpTy, Instruction::Trunc>(Op);
}

/// Matches a `trunc` that has the `nuw` (no unsigned wrap) flag set and whose
/// source operand matches \p Op.
template <typename OpTy>
inline NoWrapTrunc_match<OpTy, TruncInst::NoUnsignedWrap>
m_NUWTrunc(const OpTy &Op) {
  using MatcherTy = NoWrapTrunc_match<OpTy, TruncInst::NoUnsignedWrap>;
  return MatcherTy(Op);
}

/// Matches a `trunc` that has the `nsw` (no signed wrap) flag set and whose
/// source operand matches \p Op.
template <typename OpTy>
inline NoWrapTrunc_match<OpTy, TruncInst::NoSignedWrap>
m_NSWTrunc(const OpTy &Op) {
  using MatcherTy = NoWrapTrunc_match<OpTy, TruncInst::NoSignedWrap>;
  return MatcherTy(Op);
}

template <typename OpTy>
inline match_combine_or<CastOperator_match<OpTy, Instruction::Trunc>, OpTy>
m_TruncOrSelf(const OpTy &Op) {
Expand Down
55 changes: 30 additions & 25 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1479,19 +1479,29 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
return nullptr;
}

/// Fold icmp (trunc X), (trunc Y).
/// Fold icmp (trunc X), (zext Y).
/// Fold icmp (trunc nuw/nsw X), (trunc nuw/nsw Y).
/// Fold icmp (trunc nuw/nsw X), (zext/sext Y).
Instruction *
InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp,
const SimplifyQuery &Q) {
if (Cmp.isSigned())
return nullptr;

Value *X, *Y;
ICmpInst::Predicate Pred;
bool YIsZext = false;
bool YIsSExt = false;
// Try to match icmp (trunc X), (trunc Y)
if (match(&Cmp, m_ICmp(Pred, m_Trunc(m_Value(X)), m_Trunc(m_Value(Y))))) {
unsigned NoWrapFlags = cast<TruncInst>(Cmp.getOperand(0))->getNoWrapKind() &
cast<TruncInst>(Cmp.getOperand(1))->getNoWrapKind();
if (Cmp.isSigned()) {
// For signed comparisons, both truncs must be nsw.
if (!(NoWrapFlags & TruncInst::NoSignedWrap))
return nullptr;
} else {
// For unsigned and equality comparisons, either both must be nuw or
// both must be nsw, we don't care which.
if (!NoWrapFlags)
return nullptr;
}

if (X->getType() != Y->getType() &&
(!Cmp.getOperand(0)->hasOneUse() || !Cmp.getOperand(1)->hasOneUse()))
return nullptr;
Expand All @@ -1501,12 +1511,19 @@ InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp,
Pred = Cmp.getSwappedPredicate(Pred);
}
}
// Try to match icmp (trunc X), (zext Y)
else if (match(&Cmp, m_c_ICmp(Pred, m_Trunc(m_Value(X)),
m_OneUse(m_ZExt(m_Value(Y))))))

YIsZext = true;
else
// Try to match icmp (trunc nuw X), (zext Y)
else if (!Cmp.isSigned() &&
match(&Cmp, m_c_ICmp(Pred, m_NUWTrunc(m_Value(X)),
m_OneUse(m_ZExt(m_Value(Y)))))) {
// Can fold trunc nuw + zext for unsigned and equality predicates.
}
// Try to match icmp (trunc nsw X), (zext/sext Y)
else if (match(&Cmp, m_c_ICmp(Pred, m_NSWTrunc(m_Value(X)),
m_OneUse(m_ZExtOrSExt(m_Value(Y)))))) {
// Can fold trunc nsw + zext/sext for all predicates.
YIsSExt =
isa<SExtInst>(Cmp.getOperand(0)) || isa<SExtInst>(Cmp.getOperand(1));
} else
return nullptr;

Type *TruncTy = Cmp.getOperand(0)->getType();
Expand All @@ -1518,19 +1535,7 @@ InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp,
!isDesirableIntType(X->getType()->getScalarSizeInBits()))
return nullptr;

// Check if the trunc is unneeded.
KnownBits KnownX = llvm::computeKnownBits(X, /*Depth*/ 0, Q);
if (KnownX.countMaxActiveBits() > TruncBits)
return nullptr;

if (!YIsZext) {
// If Y is also a trunc, make sure it is unneeded.
KnownBits KnownY = llvm::computeKnownBits(Y, /*Depth*/ 0, Q);
if (KnownY.countMaxActiveBits() > TruncBits)
return nullptr;
}

Value *NewY = Builder.CreateZExtOrTrunc(Y, X->getType());
Value *NewY = Builder.CreateIntCast(Y, X->getType(), YIsSExt);
return new ICmpInst(Pred, X, NewY);
}

Expand Down
77 changes: 26 additions & 51 deletions llvm/test/Transforms/InstCombine/icmp-of-trunc-ext.ll
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,7 @@ define i1 @icmp_trunc_x_zext_y_fail_multiuse(i32 %x, i8 %y) {

define i1 @trunc_unsigned_nuw(i16 %x, i16 %y) {
; CHECK-LABEL: @trunc_unsigned_nuw(
; CHECK-NEXT: [[XT:%.*]] = trunc nuw i16 [[X:%.*]] to i8
; CHECK-NEXT: [[YT:%.*]] = trunc nuw i16 [[Y:%.*]] to i8
; CHECK-NEXT: [[C:%.*]] = icmp ult i8 [[XT]], [[YT]]
; CHECK-NEXT: [[C:%.*]] = icmp ult i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nuw i16 %x to i8
Expand All @@ -284,9 +282,7 @@ define i1 @trunc_unsigned_nuw(i16 %x, i16 %y) {

define i1 @trunc_unsigned_nsw(i16 %x, i16 %y) {
; CHECK-LABEL: @trunc_unsigned_nsw(
; CHECK-NEXT: [[XT:%.*]] = trunc nsw i16 [[X:%.*]] to i8
; CHECK-NEXT: [[YT:%.*]] = trunc nsw i16 [[Y:%.*]] to i8
; CHECK-NEXT: [[C:%.*]] = icmp ult i8 [[XT]], [[YT]]
; CHECK-NEXT: [[C:%.*]] = icmp ult i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nsw i16 %x to i8
Expand All @@ -297,9 +293,7 @@ define i1 @trunc_unsigned_nsw(i16 %x, i16 %y) {

define i1 @trunc_unsigned_both(i16 %x, i16 %y) {
; CHECK-LABEL: @trunc_unsigned_both(
; CHECK-NEXT: [[XT:%.*]] = trunc nuw nsw i16 [[X:%.*]] to i8
; CHECK-NEXT: [[YT:%.*]] = trunc nuw nsw i16 [[Y:%.*]] to i8
; CHECK-NEXT: [[C:%.*]] = icmp ult i8 [[XT]], [[YT]]
; CHECK-NEXT: [[C:%.*]] = icmp ult i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nuw nsw i16 %x to i8
Expand Down Expand Up @@ -336,9 +330,7 @@ define i1 @trunc_signed_nuw(i16 %x, i16 %y) {

define i1 @trunc_signed_nsw(i16 %x, i16 %y) {
; CHECK-LABEL: @trunc_signed_nsw(
; CHECK-NEXT: [[XT:%.*]] = trunc nsw i16 [[X:%.*]] to i8
; CHECK-NEXT: [[YT:%.*]] = trunc nsw i16 [[Y:%.*]] to i8
; CHECK-NEXT: [[C:%.*]] = icmp slt i8 [[XT]], [[YT]]
; CHECK-NEXT: [[C:%.*]] = icmp slt i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nsw i16 %x to i8
Expand All @@ -349,9 +341,7 @@ define i1 @trunc_signed_nsw(i16 %x, i16 %y) {

define i1 @trunc_signed_both(i16 %x, i16 %y) {
; CHECK-LABEL: @trunc_signed_both(
; CHECK-NEXT: [[XT:%.*]] = trunc nuw nsw i16 [[X:%.*]] to i8
; CHECK-NEXT: [[YT:%.*]] = trunc nuw nsw i16 [[Y:%.*]] to i8
; CHECK-NEXT: [[C:%.*]] = icmp slt i8 [[XT]], [[YT]]
; CHECK-NEXT: [[C:%.*]] = icmp slt i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nuw nsw i16 %x to i8
Expand All @@ -375,9 +365,7 @@ define i1 @trunc_signed_either(i16 %x, i16 %y) {

define i1 @trunc_equality_nuw(i16 %x, i16 %y) {
; CHECK-LABEL: @trunc_equality_nuw(
; CHECK-NEXT: [[XT:%.*]] = trunc nuw i16 [[X:%.*]] to i8
; CHECK-NEXT: [[YT:%.*]] = trunc nuw i16 [[Y:%.*]] to i8
; CHECK-NEXT: [[C:%.*]] = icmp eq i8 [[XT]], [[YT]]
; CHECK-NEXT: [[C:%.*]] = icmp eq i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nuw i16 %x to i8
Expand All @@ -388,9 +376,7 @@ define i1 @trunc_equality_nuw(i16 %x, i16 %y) {

define i1 @trunc_equality_nsw(i16 %x, i16 %y) {
; CHECK-LABEL: @trunc_equality_nsw(
; CHECK-NEXT: [[XT:%.*]] = trunc nsw i16 [[X:%.*]] to i8
; CHECK-NEXT: [[YT:%.*]] = trunc nsw i16 [[Y:%.*]] to i8
; CHECK-NEXT: [[C:%.*]] = icmp eq i8 [[XT]], [[YT]]
; CHECK-NEXT: [[C:%.*]] = icmp eq i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nsw i16 %x to i8
Expand All @@ -401,9 +387,7 @@ define i1 @trunc_equality_nsw(i16 %x, i16 %y) {

define i1 @trunc_equality_both(i16 %x, i16 %y) {
; CHECK-LABEL: @trunc_equality_both(
; CHECK-NEXT: [[XT:%.*]] = trunc nuw nsw i16 [[X:%.*]] to i8
; CHECK-NEXT: [[YT:%.*]] = trunc nuw nsw i16 [[Y:%.*]] to i8
; CHECK-NEXT: [[C:%.*]] = icmp eq i8 [[XT]], [[YT]]
; CHECK-NEXT: [[C:%.*]] = icmp eq i16 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nuw nsw i16 %x to i8
Expand All @@ -427,9 +411,8 @@ define i1 @trunc_equality_either(i16 %x, i16 %y) {

define i1 @trunc_unsigned_nuw_zext(i32 %x, i8 %y) {
; CHECK-LABEL: @trunc_unsigned_nuw_zext(
; CHECK-NEXT: [[XT:%.*]] = trunc nuw i32 [[X:%.*]] to i16
; CHECK-NEXT: [[YE:%.*]] = zext i8 [[Y:%.*]] to i16
; CHECK-NEXT: [[C:%.*]] = icmp ult i16 [[XT]], [[YE]]
; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[Y:%.*]] to i32
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nuw i32 %x to i16
Expand All @@ -453,9 +436,8 @@ define i1 @trunc_unsigned_nuw_sext(i32 %x, i8 %y) {

define i1 @trunc_unsigned_nsw_zext(i32 %x, i8 %y) {
; CHECK-LABEL: @trunc_unsigned_nsw_zext(
; CHECK-NEXT: [[XT:%.*]] = trunc nsw i32 [[X:%.*]] to i16
; CHECK-NEXT: [[YE:%.*]] = zext i8 [[Y:%.*]] to i16
; CHECK-NEXT: [[C:%.*]] = icmp ult i16 [[XT]], [[YE]]
; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[Y:%.*]] to i32
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nsw i32 %x to i16
Expand All @@ -466,9 +448,8 @@ define i1 @trunc_unsigned_nsw_zext(i32 %x, i8 %y) {

define i1 @trunc_unsigned_nsw_sext(i32 %x, i8 %y) {
; CHECK-LABEL: @trunc_unsigned_nsw_sext(
; CHECK-NEXT: [[XT:%.*]] = trunc nsw i32 [[X:%.*]] to i16
; CHECK-NEXT: [[YE:%.*]] = sext i8 [[Y:%.*]] to i16
; CHECK-NEXT: [[C:%.*]] = icmp ult i16 [[XT]], [[YE]]
; CHECK-NEXT: [[TMP1:%.*]] = sext i8 [[Y:%.*]] to i32
; CHECK-NEXT: [[C:%.*]] = icmp ugt i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nsw i32 %x to i16
Expand All @@ -479,9 +460,8 @@ define i1 @trunc_unsigned_nsw_sext(i32 %x, i8 %y) {

define i1 @trunc_signed_nsw_sext(i32 %x, i8 %y) {
; CHECK-LABEL: @trunc_signed_nsw_sext(
; CHECK-NEXT: [[XT:%.*]] = trunc nsw i32 [[X:%.*]] to i16
; CHECK-NEXT: [[YE:%.*]] = sext i8 [[Y:%.*]] to i16
; CHECK-NEXT: [[C:%.*]] = icmp slt i16 [[XT]], [[YE]]
; CHECK-NEXT: [[TMP1:%.*]] = sext i8 [[Y:%.*]] to i32
; CHECK-NEXT: [[C:%.*]] = icmp sgt i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nsw i32 %x to i16
Expand All @@ -492,9 +472,8 @@ define i1 @trunc_signed_nsw_sext(i32 %x, i8 %y) {

define i1 @trunc_signed_nsw_zext(i32 %x, i8 %y) {
; CHECK-LABEL: @trunc_signed_nsw_zext(
; CHECK-NEXT: [[XT:%.*]] = trunc nsw i32 [[X:%.*]] to i16
; CHECK-NEXT: [[YE:%.*]] = zext i8 [[Y:%.*]] to i16
; CHECK-NEXT: [[C:%.*]] = icmp slt i16 [[XT]], [[YE]]
; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[Y:%.*]] to i32
; CHECK-NEXT: [[C:%.*]] = icmp sgt i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nsw i32 %x to i16
Expand Down Expand Up @@ -531,9 +510,8 @@ define i1 @trunc_signed_nuw_zext(i32 %x, i8 %y) {

define i1 @trunc_equality_nuw_zext(i32 %x, i8 %y) {
; CHECK-LABEL: @trunc_equality_nuw_zext(
; CHECK-NEXT: [[XT:%.*]] = trunc nuw i32 [[X:%.*]] to i16
; CHECK-NEXT: [[YE:%.*]] = zext i8 [[Y:%.*]] to i16
; CHECK-NEXT: [[C:%.*]] = icmp ne i16 [[XT]], [[YE]]
; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[Y:%.*]] to i32
; CHECK-NEXT: [[C:%.*]] = icmp ne i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nuw i32 %x to i16
Expand All @@ -557,9 +535,8 @@ define i1 @trunc_equality_nuw_sext(i32 %x, i8 %y) {

define i1 @trunc_equality_nsw_zext(i32 %x, i8 %y) {
; CHECK-LABEL: @trunc_equality_nsw_zext(
; CHECK-NEXT: [[XT:%.*]] = trunc nsw i32 [[X:%.*]] to i16
; CHECK-NEXT: [[YE:%.*]] = zext i8 [[Y:%.*]] to i16
; CHECK-NEXT: [[C:%.*]] = icmp ne i16 [[XT]], [[YE]]
; CHECK-NEXT: [[TMP1:%.*]] = zext i8 [[Y:%.*]] to i32
; CHECK-NEXT: [[C:%.*]] = icmp ne i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nsw i32 %x to i16
Expand All @@ -570,9 +547,8 @@ define i1 @trunc_equality_nsw_zext(i32 %x, i8 %y) {

define i1 @trunc_equality_nsw_sext(i32 %x, i8 %y) {
; CHECK-LABEL: @trunc_equality_nsw_sext(
; CHECK-NEXT: [[XT:%.*]] = trunc nsw i32 [[X:%.*]] to i16
; CHECK-NEXT: [[YE:%.*]] = sext i8 [[Y:%.*]] to i16
; CHECK-NEXT: [[C:%.*]] = icmp ne i16 [[XT]], [[YE]]
; CHECK-NEXT: [[TMP1:%.*]] = sext i8 [[Y:%.*]] to i32
; CHECK-NEXT: [[C:%.*]] = icmp ne i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nsw i32 %x to i16
Expand All @@ -583,9 +559,8 @@ define i1 @trunc_equality_nsw_sext(i32 %x, i8 %y) {

define i1 @trunc_equality_both_sext(i32 %x, i8 %y) {
; CHECK-LABEL: @trunc_equality_both_sext(
; CHECK-NEXT: [[XT:%.*]] = trunc nuw nsw i32 [[X:%.*]] to i16
; CHECK-NEXT: [[YE:%.*]] = sext i8 [[Y:%.*]] to i16
; CHECK-NEXT: [[C:%.*]] = icmp ne i16 [[XT]], [[YE]]
; CHECK-NEXT: [[TMP1:%.*]] = sext i8 [[Y:%.*]] to i32
; CHECK-NEXT: [[C:%.*]] = icmp ne i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i1 [[C]]
;
%xt = trunc nuw nsw i32 %x to i16
Expand Down

0 comments on commit b0eeacb

Please sign in to comment.