[InstCombine] Generalize icmp (shl nuw C2, Y), C -> icmp Y, C3 #104696

Merged
dtcxzyw merged 6 commits into llvm:main from perf/fold-icmp-shl-nuw-c on Sep 18, 2024

@dtcxzyw dtcxzyw commented Aug 18, 2024

The motivation of this patch is to fold more generalized patterns like icmp ult (shl nuw 16, X), 64 -> icmp ult X, 2.

Alive2: https://alive2.llvm.org/ce/z/gyqjQH
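
The role of `nuw` in this fold can be checked by brute force. The following is a minimal sketch, not part of the patch: it models `shl nuw i32 16, X` with plain integer arithmetic and treats any shift that loses bits as poison, where any result is acceptable. The helper names `shlOverflows`, `foldHoldsWithNuw`, and `firstCounterexampleWithoutNuw` are hypothetical.

```cpp
#include <cstdint>

// Hypothetical helper: true if (c2 << x) does not fit in 32 bits,
// i.e. the nuw flag would make the result poison.
bool shlOverflows(uint32_t c2, unsigned x) {
  uint64_t wide = static_cast<uint64_t>(c2) << x;  // x < 32, so no UB here
  return wide > UINT32_MAX;
}

// Checks icmp ult (shl nuw 16, X), 64  ==  icmp ult X, 2 for every
// in-range shift amount that does not produce poison.
bool foldHoldsWithNuw() {
  for (unsigned x = 0; x < 32; ++x) {
    if (shlOverflows(16, x))
      continue;  // poison: the folded form may return anything
    uint32_t shl = 16u << x;
    if ((shl < 64u) != (x < 2u))
      return false;
  }
  return true;
}

// Without nuw the shift wraps modulo 2^32 and the fold can be wrong.
// Returns the first such X, or -1 if none exists.
int firstCounterexampleWithoutNuw() {
  for (unsigned x = 0; x < 32; ++x) {
    uint32_t shl = 16u << x;  // wraps modulo 2^32
    if ((shl < 64u) != (x < 2u))
      return static_cast<int>(x);
  }
  return -1;
}
```

Under these assumptions `foldHoldsWithNuw()` returns true, while `firstCounterexampleWithoutNuw()` returns 28, because `16 << 28` wraps to 0 in 32 bits. That is consistent with the negative test `fold_icmp_shl_c2_without_nuw` in the diff.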


llvmbot commented Aug 18, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

Changes

The motivation of this patch is to fold more generalized patterns like icmp ult (shl nuw 16, X), 64 -> icmp ult X, 2.

Alive2: https://alive2.llvm.org/ce/z/0xcdE6


Full diff: https://github.com/llvm/llvm-project/pull/104696.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+14-8)
  • (modified) llvm/test/Transforms/InstCombine/icmp-shl-nuw.ll (+67)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 34c9e0fde4f428..8034befb6b9449 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2227,18 +2227,24 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
   return NewC ? new ICmpInst(Pred, X, NewC) : nullptr;
 }
 
-/// Fold icmp (shl 1, Y), C.
-static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
-                                   const APInt &C) {
+/// Fold icmp (shl nuw C2, Y), C.
+static Instruction *foldICmpShlLHSC(ICmpInst &Cmp, Instruction *Shl,
+                                    const APInt &C) {
   Value *Y;
-  if (!match(Shl, m_Shl(m_One(), m_Value(Y))))
+  const APInt *C2;
+  if (!match(Shl, m_NUWShl(m_APInt(C2), m_Value(Y))))
     return nullptr;
 
   Type *ShiftType = Shl->getType();
   unsigned TypeBits = C.getBitWidth();
-  bool CIsPowerOf2 = C.isPowerOf2();
   ICmpInst::Predicate Pred = Cmp.getPredicate();
   if (Cmp.isUnsigned()) {
+    APInt Div, Rem;
+    APInt::udivrem(C, *C2, Div, Rem);
+    if (!Rem.isZero())
+      return nullptr;
+    bool CIsPowerOf2 = Div.isPowerOf2();
+
     // (1 << Y) pred C -> Y pred Log2(C)
     if (!CIsPowerOf2) {
       // (1 << Y) <  30 -> Y <= 4
@@ -2251,9 +2257,9 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
         Pred = ICmpInst::ICMP_UGT;
     }
 
-    unsigned CLog2 = C.logBase2();
+    unsigned CLog2 = Div.logBase2();
     return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2));
-  } else if (Cmp.isSigned()) {
+  } else if (Cmp.isSigned() && C2->isOne()) {
     Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1);
     // (1 << Y) >  0 -> Y != 31
     // (1 << Y) >  C -> Y != 31 if C is negative.
@@ -2307,7 +2313,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstant(ICmpInst &Cmp,
 
   const APInt *ShiftAmt;
   if (!match(Shl->getOperand(1), m_APInt(ShiftAmt)))
-    return foldICmpShlOne(Cmp, Shl, C);
+    return foldICmpShlLHSC(Cmp, Shl, C);
 
   // Check that the shift amount is in range. If not, don't perform undefined
   // shifts. When the shift is visited, it will be simplified.
diff --git a/llvm/test/Transforms/InstCombine/icmp-shl-nuw.ll b/llvm/test/Transforms/InstCombine/icmp-shl-nuw.ll
index 57c3abc7b9841f..46671f83610fd1 100644
--- a/llvm/test/Transforms/InstCombine/icmp-shl-nuw.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-shl-nuw.ll
@@ -90,3 +90,70 @@ define <2 x i1> @icmp_ugt_16x2(<2 x i32>) {
   %d = icmp ugt <2 x i32> %c, <i32 1048575, i32 1048575>
   ret <2 x i1> %d
 }
+
+define i1 @fold_icmp_shl_nuw_c1(i32 %x) {
+; CHECK-LABEL: @fold_icmp_shl_nuw_c1(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[X:%.*]], 61440
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[TMP1]], 0
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %lshr = lshr i32 %x, 12
+  %and = and i32 %lshr, 15
+  %shl = shl nuw i32 2, %and
+  %cmp = icmp ult i32 %shl, 4
+  ret i1 %cmp
+}
+
+define i1 @fold_icmp_shl_nuw_c2(i32 %x) {
+; CHECK-LABEL: @fold_icmp_shl_nuw_c2(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[X:%.*]], 2
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl nuw i32 16, %x
+  %cmp = icmp ult i32 %shl, 64
+  ret i1 %cmp
+}
+
+define i1 @fold_icmp_shl_nuw_c2_non_pow2(i32 %x) {
+; CHECK-LABEL: @fold_icmp_shl_nuw_c2_non_pow2(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[X:%.*]], 2
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl nuw i32 48, %x
+  %cmp = icmp ult i32 %shl, 192
+  ret i1 %cmp
+}
+
+define i1 @fold_icmp_shl_nuw_c2_div_non_pow2(i32 %x) {
+; CHECK-LABEL: @fold_icmp_shl_nuw_c2_div_non_pow2(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[X:%.*]], 5
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl nuw i32 2, %x
+  %cmp = icmp ult i32 %shl, 60
+  ret i1 %cmp
+}
+
+; Negative tests
+
+define i1 @fold_icmp_shl_nuw_c2_indivisible(i32 %x) {
+; CHECK-LABEL: @fold_icmp_shl_nuw_c2_indivisible(
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 16, [[X:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[SHL]], 63
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl nuw i32 16, %x
+  %cmp = icmp ult i32 %shl, 63
+  ret i1 %cmp
+}
+
+define i1 @fold_icmp_shl_c2_without_nuw(i32 %x) {
+; CHECK-LABEL: @fold_icmp_shl_c2_without_nuw(
+; CHECK-NEXT:    [[SHL:%.*]] = shl i32 16, [[X:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[SHL]], 64
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+  %shl = shl i32 16, %x
+  %cmp = icmp ult i32 %shl, 64
+  ret i1 %cmp
+}

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Aug 18, 2024
+    APInt Div, Rem;
+    APInt::udivrem(C, *C2, Div, Rem);
+    if (!Rem.isZero())
+      return nullptr;
Contributor

Can this be generalized to the non-rem-0 case? Taking your example, icmp ult (shl nuw C2, X), 64 -> icmp ult X, 2 holds for any C2 in [16,31].

Member Author

All non-rem-0 cases:

Input files: 37900
icmp ugt (shl nuw 12, X), 72057594037927935
icmp ugt (shl nuw 12, X), 72057594037927935
icmp ugt (shl nuw 12, X), 72057594037927935
icmp ugt (shl nuw 12, X), 72057594037927935
icmp ugt (shl nuw 12, X), 72057594037927935
icmp ugt (shl nuw 12, X), 72057594037927935
icmp ugt (shl nuw 4, X), 31
icmp ugt (shl nuw 2, X), 134217727
icmp ugt (shl nuw 2, X), 134217727
icmp ugt (shl nuw 2, X), 134217727
icmp ugt (shl nuw 2, X), 134217727
icmp ugt (shl nuw 2, X), 134217727
icmp ugt (shl nuw 2, X), 134217727
icmp ugt (shl nuw 2, X), 134217727
icmp ugt (shl nuw 2, X), 134217727
icmp ugt (shl nuw 16, X), 63
icmp ugt (shl nuw 2, X), 3
8
c3c/optimized/target.c.ll
folly/optimized/DynamicParser.cpp.ll
folly/optimized/LogConfigParser.cpp.ll
folly/optimized/dynamic.cpp.ll
folly/optimized/json.cpp.ll
linux/optimized/xhci-pci.ll
regex-rs/optimized/11vfjke4utuj478u.ll
wasmtime-rs/optimized/2ly4gzztxx8hlwxv.ll

Member Author

> Can this be generalized to the non-rem-0 case? Taking your example, icmp ult (shl nuw C2, X), 64 -> icmp ult X, 2 holds for any C2 in [16,31].

Do you have any ideas to handle these non-rem-0 cases? If not, I will convert icmp ugt (shl nuw C2, X), C3 -> icmp uge (shl nuw C2, X), C3 + 1 to cover more real-world cases.

Contributor

Isn't it just a matter of finding the first X value where C2 nuw<< X u< C1 is false? That is just log2(roundup_next_p2(C1 / C2)), no?
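
The suggestion above can be sketched and checked exhaustively for small constants. This is only an illustration with hypothetical names (`ceilLog2`, `firstFalseShift`, `firstFalseShiftBruteForce`, `formulaMatchesBruteForce`). It assumes 0 < C2 <= C1 and that the quotient C1 / C2 is rounded up; with truncating division the formula would be off by one for inputs such as C2 = 3, C1 = 13.

```cpp
#include <cstdint>

// Hypothetical helper: smallest k with (1 << k) >= v, for v >= 1.
// This plays the role of log2(roundup_next_p2(v)) in the comment.
unsigned ceilLog2(uint64_t v) {
  unsigned k = 0;
  while ((uint64_t{1} << k) < v)
    ++k;
  return k;
}

// Closed form for the first X where (C2 << X) u< C1 is false,
// assuming 0 < C2 <= C1 and a rounded-up quotient.
unsigned firstFalseShift(uint64_t c2, uint64_t c1) {
  uint64_t quotient = (c1 + c2 - 1) / c2;  // ceil(c1 / c2)
  return ceilLog2(quotient);
}

// Reference answer by direct search (no overflow for the small inputs used).
unsigned firstFalseShiftBruteForce(uint64_t c2, uint64_t c1) {
  unsigned x = 0;
  while ((c2 << x) < c1)
    ++x;
  return x;
}

// Compares the closed form with the search for all 1 <= c2 <= c1 <= limit.
bool formulaMatchesBruteForce(uint64_t limit) {
  for (uint64_t c1 = 1; c1 <= limit; ++c1)
    for (uint64_t c2 = 1; c2 <= c1; ++c2)
      if (firstFalseShift(c2, c1) != firstFalseShiftBruteForce(c2, c1))
        return false;
  return true;
}
```

For C1 = 64, `firstFalseShift` gives 2 for every C2 in [16, 31], which agrees with the range mentioned earlier in this thread.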

Member Author

Done. Thank you!


nikic commented Aug 18, 2024

> Alive2: https://alive2.llvm.org/ce/z/0xcdE6

These proofs don't match the implementation. Note that the pow2 check is not a pre-condition of the transform, it is only used to adjust the predicate in some cases. There is no adjustment for ule predicates, so in that case I'd expect this to verify for the implementation to be correct: https://alive2.llvm.org/ce/z/wk4LYe


dtcxzyw commented Sep 10, 2024

> Alive2: https://alive2.llvm.org/ce/z/0xcdE6
>
> These proofs don't match the implementation. Note that the pow2 check is not a pre-condition of the transform, it is only used to adjust the predicate in some cases. There is no adjustment for ule predicates, so in that case I'd expect this to verify for the implementation to be correct: https://alive2.llvm.org/ce/z/wk4LYe

Updated.


dtcxzyw commented Sep 17, 2024

Ping.

@@ -2251,9 +2257,9 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
         Pred = ICmpInst::ICMP_UGT;
     }
 
-    unsigned CLog2 = C.logBase2();
+    unsigned CLog2 = Div.logBase2();
Contributor

Can you add a test where C and C2 are not powers of two and we need to round up the log? For example, (48 << X) u>= 144?

I don't see the ceil logic for that case.
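
For the requested example the quotient is exact (144 / 48 = 3), so it can be worked through with the floor log plus a predicate adjustment. The sketch below is a standalone model, not the LLVM code: the enum `Pred`, the struct `Folded`, and the functions `evalPred`, `foldShlLHSC`, and `foldIsSound` are hypothetical. It assumes that the lines elided from the hunk map `ult` to `ule` and `uge` to `ugt` when the quotient is not a power of two, as the visible comment `(1 << Y) < 30 -> Y <= 4` suggests.

```cpp
#include <cstdint>

enum class Pred { ULT, ULE, UGT, UGE };

struct Folded {
  bool ok;       // false if the model declines to fold
  Pred pred;     // predicate applied to the shift amount
  unsigned rhs;  // constant the shift amount is compared with
};

bool evalPred(Pred p, uint64_t a, uint64_t b) {
  switch (p) {
  case Pred::ULT: return a < b;
  case Pred::ULE: return a <= b;
  case Pred::UGT: return a > b;
  case Pred::UGE: return a >= b;
  }
  return false;
}

// Models icmp pred (shl nuw c2, Y), c  ->  icmp pred' Y, log2(c / c2)
// for the exactly divisible case only (and c >= c2 > 0).
Folded foldShlLHSC(Pred pred, uint64_t c2, uint64_t c) {
  if (c2 == 0 || c < c2 || c % c2 != 0)
    return {false, pred, 0};
  uint64_t div = c / c2;
  unsigned floorLog = 0;  // floor(log2(div))
  while ((div >> (floorLog + 1)) != 0)
    ++floorLog;
  bool isPow2 = (div & (div - 1)) == 0;
  if (!isPow2) {
    // Assumed adjustment: flip strictness so that the floor log is correct.
    if (pred == Pred::ULT) pred = Pred::ULE;
    else if (pred == Pred::UGE) pred = Pred::UGT;
  }
  return {true, pred, floorLog};
}

// Brute-force check against a 32-bit shl nuw; overflowing shifts are poison.
bool foldIsSound(Pred pred, uint32_t c2, uint32_t c) {
  Folded f = foldShlLHSC(pred, c2, c);
  if (!f.ok)
    return true;  // nothing was folded, so nothing can be wrong
  for (unsigned y = 0; y < 32; ++y) {
    uint64_t wide = static_cast<uint64_t>(c2) << y;
    if (wide > UINT32_MAX)
      continue;  // poison: any result is allowed
    if (evalPred(pred, wide, c) != evalPred(f.pred, y, f.rhs))
      return false;
  }
  return true;
}
```

Under these assumptions `(48 << X) u>= 144` becomes `X u> 1`, and `(2 << X) u< 60` becomes `X u<= 4`, which is the same condition as the `icmp ult X, 5` expected by `fold_icmp_shl_nuw_c2_div_non_pow2`.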

Member Author

@goldsteinn goldsteinn left a comment

LGTM. Give nikic a day before pushing.

@nikic nikic left a comment

LGTM

@dtcxzyw dtcxzyw merged commit 872932b into llvm:main Sep 18, 2024
8 checks passed
@dtcxzyw dtcxzyw deleted the perf/fold-icmp-shl-nuw-c branch September 18, 2024 11:10
tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
…m#104696)

The motivation of this patch is to fold more generalized patterns like
`icmp ult (shl nuw 16, X), 64 -> icmp ult X, 2`.

Alive2: https://alive2.llvm.org/ce/z/gyqjQH