Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstSimplify] Fold X * (2^N + 1) >> N -> X when N is half the bitwidth of X #92909

Closed
wants to merge 1 commit into from

Conversation

AZero13
Copy link
Contributor

@AZero13 AZero13 commented May 21, 2024

No description provided.

@AZero13 AZero13 requested a review from nikic as a code owner May 21, 2024 12:45
@AZero13 AZero13 changed the title [Do not merge until #92907 is merged first] Fold X * (2^N + 1) >> N -> X when N is half the bitwidth of X May 21, 2024
@llvmbot
Copy link
Member

llvmbot commented May 21, 2024

@llvm/pr-subscribers-llvm-analysis

Author: AtariDreams (AtariDreams)

Changes

This depends on #92907 being merged first. Once that is done, I will update the title and description of this.


Full diff: https://github.com/llvm/llvm-project/pull/92909.diff

4 Files Affected:

  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+39)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+47-20)
  • (modified) llvm/test/Transforms/InstCombine/ashr-lshr.ll (+281)
  • (modified) llvm/test/Transforms/InstCombine/lshr.ll (+28-8)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 53a974c5294c6..15f71dcb87620 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1479,6 +1479,29 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1))))
     return X;
 
+  // Look for a "splat" mul pattern - it replicates bits across each half
+  // of a value, so a right shift is just a mask of the low bits:
+  const APInt *MulC;
+  const APInt *ShAmt;
+  if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) &&
+      match(Op1, m_APInt(ShAmt))) {
+    unsigned ShAmtC = ShAmt->getZExtValue();
+    unsigned BitWidth = ShAmt->getBitWidth();
+    if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC) {
+      // FIXME: This condition should be covered by the computeKnownBits, but
+      // for some reason it is not, so keep this in for now. This has no
+      // negative effects, but KnownBits should be able to infer a number of
+      // leading bits based on 2^N + 1 not wrapping, as that means 2^N must not
+      // wrap either, which means the top N bits of X must be 0.
+      if (ShAmtC * 2 == BitWidth)
+        return X;
+      const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
+      if (XKnown.countMaxActiveBits() <= ShAmtC)
+        return X;
+    }
+  }
+
   // ((X << A) | Y) >> A -> X  if effective width of Y is not larger than A.
   // We can return X as we do in the above case since OR alters no bits in X.
   // SimplifyDemandedBits in InstCombine can do more general optimization for
@@ -1523,6 +1546,22 @@ static Value *simplifyAShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1))))
     return X;
 
+  const APInt *MulC;
+  const APInt *ShAmt;
+  if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) &&
+      match(Op1, m_APInt(ShAmt)) &&
+      cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap()) {
+    unsigned ShAmtC = ShAmt->getZExtValue();
+    unsigned BitWidth = ShAmt->getBitWidth();
+    if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC &&
+        ShAmtC < BitWidth - 1) /* Minus 1 for the sign bit */ {
+      KnownBits KnownX = computeKnownBits(X, /* Depth */ 0, Q);
+      if (KnownX.countMaxActiveBits() <= ShAmtC)
+        return X;
+    }
+  }
+
   // Arithmetic shifting an all-sign-bit value is a no-op.
   unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
   if (NumSignBits == Op0->getType()->getScalarSizeInBits())
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index ba297111d945f..8dd0f2f61756c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1456,30 +1456,42 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
     }
 
     const APInt *MulC;
-    if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) {
-      // Look for a "splat" mul pattern - it replicates bits across each half of
-      // a value, so a right shift is just a mask of the low bits:
-      // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
-      // TODO: Generalize to allow more than just half-width shifts?
-      if (BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() &&
-          MulC->logBase2() == ShAmtC)
-        return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));
+    if (match(Op0, m_OneUse(m_NUWMul(m_Value(X), m_APInt(MulC))))) {
+      if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+          MulC->logBase2() == ShAmtC) {
+
+        // lshr (mul nuw (X, 2^N + 1)), N -> add nuw (X, lshr(X, N))
+        auto *NewAdd = BinaryOperator::CreateNUWAdd(
+            X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                  I.isExact()));
+        NewAdd->setHasNoSignedWrap(
+            cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap());
+        return NewAdd;
+      }
 
       // The one-use check is not strictly necessary, but codegen may not be
       // able to invert the transform and perf may suffer with an extra mul
       // instruction.
-      if (Op0->hasOneUse()) {
-        APInt NewMulC = MulC->lshr(ShAmtC);
-        // if c is divisible by (1 << ShAmtC):
-        // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
-        if (MulC->eq(NewMulC.shl(ShAmtC))) {
-          auto *NewMul =
-              BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
-          assert(ShAmtC != 0 &&
-                 "lshr X, 0 should be handled by simplifyLShrInst.");
-          NewMul->setHasNoSignedWrap(true);
-          return NewMul;
-        }
+      APInt NewMulC = MulC->lshr(ShAmtC);
+      // if c is divisible by (1 << ShAmtC):
+      // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
+      if (MulC->eq(NewMulC.shl(ShAmtC))) {
+        auto *NewMul =
+            BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
+        assert(ShAmtC != 0 &&
+               "lshr X, 0 should be handled by simplifyLShrInst.");
+        NewMul->setHasNoSignedWrap(true);
+        return NewMul;
+      }
+    }
+
+    // lshr (mul nsw (X, 2^N + 1)), N -> add nsw (X, lshr(X, N))
+    if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC))))) {
+      if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+          MulC->logBase2() == ShAmtC) {
+        return BinaryOperator::CreateNSWAdd(
+            X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                  I.isExact()));
       }
     }
 
@@ -1686,6 +1698,21 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
       if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
         return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
     }
+
+    const APInt *MulC;
+    if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC)))) &&
+        (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+         MulC->logBase2() == ShAmt &&
+         (ShAmt < BitWidth - 1))) /* Minus 1 for the sign bit */ {
+
+      // ashr (mul nsw (X, 2^N + 1)), N -> add nsw (X, ashr(X, N))
+      auto *NewAdd = BinaryOperator::CreateNSWAdd(
+          X,
+          Builder.CreateAShr(X, ConstantInt::get(Ty, ShAmt), "", I.isExact()));
+      NewAdd->setHasNoUnsignedWrap(
+          cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap());
+      return NewAdd;
+    }
   }
 
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
diff --git a/llvm/test/Transforms/InstCombine/ashr-lshr.ll b/llvm/test/Transforms/InstCombine/ashr-lshr.ll
index ac206dc7999dd..f426755dfc9dd 100644
--- a/llvm/test/Transforms/InstCombine/ashr-lshr.ll
+++ b/llvm/test/Transforms/InstCombine/ashr-lshr.ll
@@ -604,3 +604,284 @@ define <2 x i8> @ashr_known_pos_exact_vec(<2 x i8> %x, <2 x i8> %y) {
   %r = ashr exact <2 x i8> %p, %y
   ret <2 x i8> %r
 }
+
+define i32 @lshr_mul_times_3_div_2(i32 %0) {
+; CHECK-LABEL: @lshr_mul_times_3_div_2(
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nsw nuw i32 %0, 3
+  %lshr = lshr i32 %mul, 1
+  ret i32 %lshr
+}
+
+define i32 @lshr_mul_times_3_div_2_exact(i32 %x) {
+; CHECK-LABEL: @lshr_mul_times_3_div_2_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nsw i32 %x, 3
+  %lshr = lshr exact i32 %mul, 1
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @lshr_mul_times_3_div_2_no_flags(i32 %0) {
+; CHECK-LABEL: @lshr_mul_times_3_div_2_no_flags(
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[TMP0:%.*]], 3
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[MUL]], 1
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul i32 %0, 3
+  %lshr = lshr i32 %mul, 1
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @mul_times_3_div_2_multiuse_lshr(i32 %x) {
+; CHECK-LABEL: @mul_times_3_div_2_multiuse_lshr(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[X:%.*]], 3
+; CHECK-NEXT:    [[RES:%.*]] = lshr i32 [[MUL]], 1
+; CHECK-NEXT:    call void @use(i32 [[MUL]])
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %mul = mul nuw i32 %x, 3
+  %res = lshr i32 %mul, 1
+  call void @use(i32 %mul)
+  ret i32 %res
+}
+
+define i32 @lshr_mul_times_3_div_2_exact_2(i32 %x) {
+; CHECK-LABEL: @lshr_mul_times_3_div_2_exact_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nuw i32 %x, 3
+  %lshr = lshr exact i32 %mul, 1
+  ret i32 %lshr
+}
+
+define i32 @lshr_mul_times_5_div_4(i32 %0) {
+; CHECK-LABEL: @lshr_mul_times_5_div_4(
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 2
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nsw nuw i32 %0, 5
+  %lshr = lshr i32 %mul, 2
+  ret i32 %lshr
+}
+
+define i32 @lshr_mul_times_5_div_4_exact(i32 %x) {
+; CHECK-LABEL: @lshr_mul_times_5_div_4_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[LSHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nsw i32 %x, 5
+  %lshr = lshr exact i32 %mul, 2
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @lshr_mul_times_5_div_4_no_flags(i32 %0) {
+; CHECK-LABEL: @lshr_mul_times_5_div_4_no_flags(
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[TMP0:%.*]], 5
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[MUL]], 2
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul i32 %0, 5
+  %lshr = lshr i32 %mul, 2
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @mul_times_5_div_4_multiuse_lshr(i32 %x) {
+; CHECK-LABEL: @mul_times_5_div_4_multiuse_lshr(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[X:%.*]], 5
+; CHECK-NEXT:    [[RES:%.*]] = lshr i32 [[MUL]], 2
+; CHECK-NEXT:    call void @use(i32 [[MUL]])
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %mul = mul nuw i32 %x, 5
+  %res = lshr i32 %mul, 2
+  call void @use(i32 %mul)
+  ret i32 %res
+}
+
+define i32 @lshr_mul_times_5_div_4_exact_2(i32 %x) {
+; CHECK-LABEL: @lshr_mul_times_5_div_4_exact_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nuw i32 %x, 5
+  %lshr = lshr exact i32 %mul, 2
+  ret i32 %lshr
+}
+
+define i32 @ashr_mul_times_3_div_2(i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2(
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nuw nsw i32 %0, 3
+  %ashr = ashr i32 %mul, 1
+  ret i32 %ashr
+}
+
+define i32 @ashr_mul_times_3_div_2_exact(i32 %x) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nsw i32 %x, 3
+  %ashr = ashr exact i32 %mul, 1
+  ret i32 %ashr
+}
+
+; Negative test
+
+define i32 @ashr_mul_times_3_div_2_no_flags(i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2_no_flags(
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[TMP0:%.*]], 3
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[MUL]], 1
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul i32 %0, 3
+  %ashr = ashr i32 %mul, 1
+  ret i32 %ashr
+}
+
+; Negative test
+
+define i32 @ashr_mul_times_3_div_2_no_nsw(i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2_no_nsw(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[TMP0:%.*]], 3
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[MUL]], 1
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nuw i32 %0, 3
+  %ashr = ashr i32 %mul, 1
+  ret i32 %ashr
+}
+
+; Negative test
+
+define i32 @mul_times_3_div_2_multiuse_ashr(i32 %x) {
+; CHECK-LABEL: @mul_times_3_div_2_multiuse_ashr(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 3
+; CHECK-NEXT:    [[RES:%.*]] = ashr i32 [[MUL]], 1
+; CHECK-NEXT:    call void @use(i32 [[MUL]])
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %mul = mul nsw i32 %x, 3
+  %res = ashr i32 %mul, 1
+  call void @use(i32 %mul)
+  ret i32 %res
+}
+
+define i32 @ashr_mul_times_3_div_2_exact_2(i32 %x) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2_exact_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nsw i32 %x, 3
+  %ashr = ashr exact i32 %mul, 1
+  ret i32 %ashr
+}
+
+define i32 @ashr_mul_times_5_div_4(i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_5_div_4(
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr i32 [[TMP0:%.*]], 2
+; CHECK-NEXT:    [[ASHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nuw nsw i32 %0, 5
+  %ashr = ashr i32 %mul, 2
+  ret i32 %ashr
+}
+
+define i32 @ashr_mul_times_5_div_4_exact(i32 %x) {
+; CHECK-LABEL: @ashr_mul_times_5_div_4_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nsw i32 %x, 5
+  %ashr = ashr exact i32 %mul, 2
+  ret i32 %ashr
+}
+
+; Negative test
+
+define i32 @ashr_mul_times_5_div_4_no_flags(i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_5_div_4_no_flags(
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[TMP0:%.*]], 5
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[MUL]], 2
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul i32 %0, 5
+  %ashr = ashr i32 %mul, 2
+  ret i32 %ashr
+}
+
+; Negative test
+
+define i32 @mul_times_5_div_4_multiuse_ashr(i32 %x) {
+; CHECK-LABEL: @mul_times_5_div_4_multiuse_ashr(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 5
+; CHECK-NEXT:    [[RES:%.*]] = ashr i32 [[MUL]], 2
+; CHECK-NEXT:    call void @use(i32 [[MUL]])
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %mul = mul nsw i32 %x, 5
+  %res = ashr i32 %mul, 2
+  call void @use(i32 %mul)
+  ret i32 %res
+}
+
+define i32 @ashr_mul_times_5_div_4_exact_2(i32 %x) {
+; CHECK-LABEL: @ashr_mul_times_5_div_4_exact_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nsw i32 %x, 5
+  %ashr = ashr exact i32 %mul, 2
+  ret i32 %ashr
+}
+
+define i32 @mul_splat_fold_known_active_bits(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_known_active_bits(
+; CHECK-NEXT:    [[XX:%.*]] = and i32 [[X:%.*]], 360
+; CHECK-NEXT:    ret i32 [[XX]]
+;
+  %xx = and i32 %x, 360
+  %m = mul nuw i32 %xx, 65537
+  %t = ashr i32 %m, 16
+  ret i32 %t
+}
+
+define i32 @mul_splat_fold_no_known_active_bits(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_no_known_active_bits(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i32 [[X:%.*]], 16
+; CHECK-NEXT:    [[T:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[T]]
+;
+  %m = mul nsw i32 %x, 65537
+  %t = ashr i32 %m, 16
+  ret i32 %t
+}
+
+declare void @use(i32)
diff --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll
index fa92c1c4b3be4..17b08985ee90e 100644
--- a/llvm/test/Transforms/InstCombine/lshr.ll
+++ b/llvm/test/Transforms/InstCombine/lshr.ll
@@ -348,22 +348,31 @@ define <2 x i32> @narrow_lshr_constant(<2 x i8> %x, <2 x i8> %y) {
 
 define i32 @mul_splat_fold(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold(
-; CHECK-NEXT:    [[T:%.*]] = and i32 [[X:%.*]], 65535
-; CHECK-NEXT:    ret i32 [[T]]
+; CHECK-NEXT:    ret i32 [[X:%.*]]
 ;
   %m = mul nuw i32 %x, 65537
   %t = lshr i32 %m, 16
   ret i32 %t
 }
 
+define i32 @mul_splat_fold_known_zeros(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_known_zeros(
+; CHECK-NEXT:    [[XX:%.*]] = and i32 [[X:%.*]], 360
+; CHECK-NEXT:    ret i32 [[XX]]
+;
+  %xx = and i32 %x, 360
+  %m = mul nuw i32 %xx, 65537
+  %t = lshr i32 %m, 16
+  ret i32 %t
+}
+
 ; Vector type, extra use, weird types are all ok.
 
 define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {
 ; CHECK-LABEL: @mul_splat_fold_vec(
 ; CHECK-NEXT:    [[M:%.*]] = mul nuw <3 x i14> [[X:%.*]], <i14 129, i14 129, i14 129>
 ; CHECK-NEXT:    call void @usevec(<3 x i14> [[M]])
-; CHECK-NEXT:    [[T:%.*]] = and <3 x i14> [[X]], <i14 127, i14 127, i14 127>
-; CHECK-NEXT:    ret <3 x i14> [[T]]
+; CHECK-NEXT:    ret <3 x i14> [[X]]
 ;
   %m = mul nuw <3 x i14> %x, <i14 129, i14 129, i14 129>
   call void @usevec(<3 x i14> %m)
@@ -628,12 +637,10 @@ define i32 @mul_splat_fold_wrong_lshr_const(i32 %x) {
   ret i32 %t
 }
 
-; Negative test
-
 define i32 @mul_splat_fold_no_nuw(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_no_nuw(
-; CHECK-NEXT:    [[M:%.*]] = mul nsw i32 [[X:%.*]], 65537
-; CHECK-NEXT:    [[T:%.*]] = lshr i32 [[M]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 16
+; CHECK-NEXT:    [[T:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[T]]
 ;
   %m = mul nsw i32 %x, 65537
@@ -641,6 +648,19 @@ define i32 @mul_splat_fold_no_nuw(i32 %x) {
   ret i32 %t
 }
 
+; Negative test
+
+define i32 @mul_splat_fold_no_flags(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_no_flags(
+; CHECK-NEXT:    [[M:%.*]] = mul i32 [[X:%.*]], 65537
+; CHECK-NEXT:    [[T:%.*]] = lshr i32 [[M]], 16
+; CHECK-NEXT:    ret i32 [[T]]
+;
+  %m = mul i32 %x, 65537
+  %t = lshr i32 %m, 16
+  ret i32 %t
+}
+
 ; Negative test (but simplifies before we reach the mul_splat transform)- need more than 2 bits
 
 define i2 @mul_splat_fold_too_narrow(i2 %x) {

@llvmbot
Copy link
Member

llvmbot commented May 21, 2024

@llvm/pr-subscribers-llvm-transforms

Author: AtariDreams (AtariDreams)

Changes

This depends on #92907 being merged first. Once that is done, I will update the title and description of this.


Full diff: https://github.com/llvm/llvm-project/pull/92909.diff

4 Files Affected:

  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+39)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+47-20)
  • (modified) llvm/test/Transforms/InstCombine/ashr-lshr.ll (+281)
  • (modified) llvm/test/Transforms/InstCombine/lshr.ll (+28-8)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 53a974c5294c6..15f71dcb87620 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1479,6 +1479,29 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1))))
     return X;
 
+  // Look for a "splat" mul pattern - it replicates bits across each half
+  // of a value, so a right shift is just a mask of the low bits:
+  const APInt *MulC;
+  const APInt *ShAmt;
+  if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) &&
+      match(Op1, m_APInt(ShAmt))) {
+    unsigned ShAmtC = ShAmt->getZExtValue();
+    unsigned BitWidth = ShAmt->getBitWidth();
+    if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC) {
+      // FIXME: This condition should be covered by the computeKnownBits, but
+      // for some reason it is not, so keep this in for now. This has no
+      // negative effects, but KnownBits should be able to infer a number of
+      // leading bits based on 2^N + 1 not wrapping, as that means 2^N must not
+      // wrap either, which means the top N bits of X must be 0.
+      if (ShAmtC * 2 == BitWidth)
+        return X;
+      const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
+      if (XKnown.countMaxActiveBits() <= ShAmtC)
+        return X;
+    }
+  }
+
   // ((X << A) | Y) >> A -> X  if effective width of Y is not larger than A.
   // We can return X as we do in the above case since OR alters no bits in X.
   // SimplifyDemandedBits in InstCombine can do more general optimization for
@@ -1523,6 +1546,22 @@ static Value *simplifyAShrInst(Value *Op0, Value *Op1, bool IsExact,
   if (Q.IIQ.UseInstrInfo && match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1))))
     return X;
 
+  const APInt *MulC;
+  const APInt *ShAmt;
+  if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) &&
+      match(Op1, m_APInt(ShAmt)) &&
+      cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap()) {
+    unsigned ShAmtC = ShAmt->getZExtValue();
+    unsigned BitWidth = ShAmt->getBitWidth();
+    if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+        MulC->logBase2() == ShAmtC &&
+        ShAmtC < BitWidth - 1) /* Minus 1 for the sign bit */ {
+      KnownBits KnownX = computeKnownBits(X, /* Depth */ 0, Q);
+      if (KnownX.countMaxActiveBits() <= ShAmtC)
+        return X;
+    }
+  }
+
   // Arithmetic shifting an all-sign-bit value is a no-op.
   unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
   if (NumSignBits == Op0->getType()->getScalarSizeInBits())
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index ba297111d945f..8dd0f2f61756c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1456,30 +1456,42 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
     }
 
     const APInt *MulC;
-    if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC)))) {
-      // Look for a "splat" mul pattern - it replicates bits across each half of
-      // a value, so a right shift is just a mask of the low bits:
-      // lshr i[2N] (mul nuw X, (2^N)+1), N --> and iN X, (2^N)-1
-      // TODO: Generalize to allow more than just half-width shifts?
-      if (BitWidth > 2 && ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() &&
-          MulC->logBase2() == ShAmtC)
-        return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2));
+    if (match(Op0, m_OneUse(m_NUWMul(m_Value(X), m_APInt(MulC))))) {
+      if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+          MulC->logBase2() == ShAmtC) {
+
+        // lshr (mul nuw (X, 2^N + 1)), N -> add nuw (X, lshr(X, N))
+        auto *NewAdd = BinaryOperator::CreateNUWAdd(
+            X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                  I.isExact()));
+        NewAdd->setHasNoSignedWrap(
+            cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap());
+        return NewAdd;
+      }
 
       // The one-use check is not strictly necessary, but codegen may not be
       // able to invert the transform and perf may suffer with an extra mul
       // instruction.
-      if (Op0->hasOneUse()) {
-        APInt NewMulC = MulC->lshr(ShAmtC);
-        // if c is divisible by (1 << ShAmtC):
-        // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
-        if (MulC->eq(NewMulC.shl(ShAmtC))) {
-          auto *NewMul =
-              BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
-          assert(ShAmtC != 0 &&
-                 "lshr X, 0 should be handled by simplifyLShrInst.");
-          NewMul->setHasNoSignedWrap(true);
-          return NewMul;
-        }
+      APInt NewMulC = MulC->lshr(ShAmtC);
+      // if c is divisible by (1 << ShAmtC):
+      // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC)
+      if (MulC->eq(NewMulC.shl(ShAmtC))) {
+        auto *NewMul =
+            BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC));
+        assert(ShAmtC != 0 &&
+               "lshr X, 0 should be handled by simplifyLShrInst.");
+        NewMul->setHasNoSignedWrap(true);
+        return NewMul;
+      }
+    }
+
+    // lshr (mul nsw (X, 2^N + 1)), N -> add nsw (X, lshr(X, N))
+    if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC))))) {
+      if (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+          MulC->logBase2() == ShAmtC) {
+        return BinaryOperator::CreateNSWAdd(
+            X, Builder.CreateLShr(X, ConstantInt::get(Ty, ShAmtC), "",
+                                  I.isExact()));
       }
     }
 
@@ -1686,6 +1698,21 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
       if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
         return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
     }
+
+    const APInt *MulC;
+    if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_APInt(MulC)))) &&
+        (BitWidth > 2 && (*MulC - 1).isPowerOf2() &&
+         MulC->logBase2() == ShAmt &&
+         (ShAmt < BitWidth - 1))) /* Minus 1 for the sign bit */ {
+
+      // ashr (mul nsw (X, 2^N + 1)), N -> add nsw (X, ashr(X, N))
+      auto *NewAdd = BinaryOperator::CreateNSWAdd(
+          X,
+          Builder.CreateAShr(X, ConstantInt::get(Ty, ShAmt), "", I.isExact()));
+      NewAdd->setHasNoUnsignedWrap(
+          cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap());
+      return NewAdd;
+    }
   }
 
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
diff --git a/llvm/test/Transforms/InstCombine/ashr-lshr.ll b/llvm/test/Transforms/InstCombine/ashr-lshr.ll
index ac206dc7999dd..f426755dfc9dd 100644
--- a/llvm/test/Transforms/InstCombine/ashr-lshr.ll
+++ b/llvm/test/Transforms/InstCombine/ashr-lshr.ll
@@ -604,3 +604,284 @@ define <2 x i8> @ashr_known_pos_exact_vec(<2 x i8> %x, <2 x i8> %y) {
   %r = ashr exact <2 x i8> %p, %y
   ret <2 x i8> %r
 }
+
+define i32 @lshr_mul_times_3_div_2(i32 %0) {
+; CHECK-LABEL: @lshr_mul_times_3_div_2(
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nsw nuw i32 %0, 3
+  %lshr = lshr i32 %mul, 1
+  ret i32 %lshr
+}
+
+define i32 @lshr_mul_times_3_div_2_exact(i32 %x) {
+; CHECK-LABEL: @lshr_mul_times_3_div_2_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nsw i32 %x, 3
+  %lshr = lshr exact i32 %mul, 1
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @lshr_mul_times_3_div_2_no_flags(i32 %0) {
+; CHECK-LABEL: @lshr_mul_times_3_div_2_no_flags(
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[TMP0:%.*]], 3
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[MUL]], 1
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul i32 %0, 3
+  %lshr = lshr i32 %mul, 1
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @mul_times_3_div_2_multiuse_lshr(i32 %x) {
+; CHECK-LABEL: @mul_times_3_div_2_multiuse_lshr(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[X:%.*]], 3
+; CHECK-NEXT:    [[RES:%.*]] = lshr i32 [[MUL]], 1
+; CHECK-NEXT:    call void @use(i32 [[MUL]])
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %mul = mul nuw i32 %x, 3
+  %res = lshr i32 %mul, 1
+  call void @use(i32 %mul)
+  ret i32 %res
+}
+
+define i32 @lshr_mul_times_3_div_2_exact_2(i32 %x) {
+; CHECK-LABEL: @lshr_mul_times_3_div_2_exact_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nuw i32 %x, 3
+  %lshr = lshr exact i32 %mul, 1
+  ret i32 %lshr
+}
+
+define i32 @lshr_mul_times_5_div_4(i32 %0) {
+; CHECK-LABEL: @lshr_mul_times_5_div_4(
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 2
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nsw nuw i32 %0, 5
+  %lshr = lshr i32 %mul, 2
+  ret i32 %lshr
+}
+
+define i32 @lshr_mul_times_5_div_4_exact(i32 %x) {
+; CHECK-LABEL: @lshr_mul_times_5_div_4_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[LSHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nsw i32 %x, 5
+  %lshr = lshr exact i32 %mul, 2
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @lshr_mul_times_5_div_4_no_flags(i32 %0) {
+; CHECK-LABEL: @lshr_mul_times_5_div_4_no_flags(
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[TMP0:%.*]], 5
+; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[MUL]], 2
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul i32 %0, 5
+  %lshr = lshr i32 %mul, 2
+  ret i32 %lshr
+}
+
+; Negative test
+
+define i32 @mul_times_5_div_4_multiuse_lshr(i32 %x) {
+; CHECK-LABEL: @mul_times_5_div_4_multiuse_lshr(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[X:%.*]], 5
+; CHECK-NEXT:    [[RES:%.*]] = lshr i32 [[MUL]], 2
+; CHECK-NEXT:    call void @use(i32 [[MUL]])
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %mul = mul nuw i32 %x, 5
+  %res = lshr i32 %mul, 2
+  call void @use(i32 %mul)
+  ret i32 %res
+}
+
+define i32 @lshr_mul_times_5_div_4_exact_2(i32 %x) {
+; CHECK-LABEL: @lshr_mul_times_5_div_4_exact_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[LSHR]]
+;
+  %mul = mul nuw i32 %x, 5
+  %lshr = lshr exact i32 %mul, 2
+  ret i32 %lshr
+}
+
+define i32 @ashr_mul_times_3_div_2(i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2(
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nuw nsw i32 %0, 3
+  %ashr = ashr i32 %mul, 1
+  ret i32 %ashr
+}
+
+define i32 @ashr_mul_times_3_div_2_exact(i32 %x) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nsw i32 %x, 3
+  %ashr = ashr exact i32 %mul, 1
+  ret i32 %ashr
+}
+
+; Negative test
+
+define i32 @ashr_mul_times_3_div_2_no_flags(i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2_no_flags(
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[TMP0:%.*]], 3
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[MUL]], 1
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul i32 %0, 3
+  %ashr = ashr i32 %mul, 1
+  ret i32 %ashr
+}
+
+; Negative test
+
+define i32 @ashr_mul_times_3_div_2_no_nsw(i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2_no_nsw(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[TMP0:%.*]], 3
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[MUL]], 1
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nuw i32 %0, 3
+  %ashr = ashr i32 %mul, 1
+  ret i32 %ashr
+}
+
+; Negative test
+
+define i32 @mul_times_3_div_2_multiuse_ashr(i32 %x) {
+; CHECK-LABEL: @mul_times_3_div_2_multiuse_ashr(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 3
+; CHECK-NEXT:    [[RES:%.*]] = ashr i32 [[MUL]], 1
+; CHECK-NEXT:    call void @use(i32 [[MUL]])
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %mul = mul nsw i32 %x, 3
+  %res = ashr i32 %mul, 1
+  call void @use(i32 %mul)
+  ret i32 %res
+}
+
+define i32 @ashr_mul_times_3_div_2_exact_2(i32 %x) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2_exact_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nsw i32 %x, 3
+  %ashr = ashr exact i32 %mul, 1
+  ret i32 %ashr
+}
+
+define i32 @ashr_mul_times_5_div_4(i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_5_div_4(
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr i32 [[TMP0:%.*]], 2
+; CHECK-NEXT:    [[ASHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nuw nsw i32 %0, 5
+  %ashr = ashr i32 %mul, 2
+  ret i32 %ashr
+}
+
+define i32 @ashr_mul_times_5_div_4_exact(i32 %x) {
+; CHECK-LABEL: @ashr_mul_times_5_div_4_exact(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nsw i32 %x, 5
+  %ashr = ashr exact i32 %mul, 2
+  ret i32 %ashr
+}
+
+; Negative test
+
+define i32 @ashr_mul_times_5_div_4_no_flags(i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_5_div_4_no_flags(
+; CHECK-NEXT:    [[MUL:%.*]] = mul i32 [[TMP0:%.*]], 5
+; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[MUL]], 2
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul i32 %0, 5
+  %ashr = ashr i32 %mul, 2
+  ret i32 %ashr
+}
+
+; Negative test
+
+define i32 @mul_times_5_div_4_multiuse_ashr(i32 %x) {
+; CHECK-LABEL: @mul_times_5_div_4_multiuse_ashr(
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 5
+; CHECK-NEXT:    [[RES:%.*]] = ashr i32 [[MUL]], 2
+; CHECK-NEXT:    call void @use(i32 [[MUL]])
+; CHECK-NEXT:    ret i32 [[RES]]
+;
+  %mul = mul nsw i32 %x, 5
+  %res = ashr i32 %mul, 2
+  call void @use(i32 %mul)
+  ret i32 %res
+}
+
+define i32 @ashr_mul_times_5_div_4_exact_2(i32 %x) {
+; CHECK-LABEL: @ashr_mul_times_5_div_4_exact_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 2
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[ASHR]]
+;
+  %mul = mul nsw i32 %x, 5
+  %ashr = ashr exact i32 %mul, 2
+  ret i32 %ashr
+}
+
+define i32 @mul_splat_fold_known_active_bits(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_known_active_bits(
+; CHECK-NEXT:    [[XX:%.*]] = and i32 [[X:%.*]], 360
+; CHECK-NEXT:    ret i32 [[XX]]
+;
+  %xx = and i32 %x, 360
+  %m = mul nuw i32 %xx, 65537
+  %t = ashr i32 %m, 16
+  ret i32 %t
+}
+
+define i32 @mul_splat_fold_no_known_active_bits(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_no_known_active_bits(
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr i32 [[X:%.*]], 16
+; CHECK-NEXT:    [[T:%.*]] = add nsw i32 [[TMP1]], [[X]]
+; CHECK-NEXT:    ret i32 [[T]]
+;
+  %m = mul nsw i32 %x, 65537
+  %t = ashr i32 %m, 16
+  ret i32 %t
+}
+
+declare void @use(i32)
diff --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll
index fa92c1c4b3be4..17b08985ee90e 100644
--- a/llvm/test/Transforms/InstCombine/lshr.ll
+++ b/llvm/test/Transforms/InstCombine/lshr.ll
@@ -348,22 +348,31 @@ define <2 x i32> @narrow_lshr_constant(<2 x i8> %x, <2 x i8> %y) {
 
 define i32 @mul_splat_fold(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold(
-; CHECK-NEXT:    [[T:%.*]] = and i32 [[X:%.*]], 65535
-; CHECK-NEXT:    ret i32 [[T]]
+; CHECK-NEXT:    ret i32 [[X:%.*]]
 ;
   %m = mul nuw i32 %x, 65537
   %t = lshr i32 %m, 16
   ret i32 %t
 }
 
+define i32 @mul_splat_fold_known_zeros(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_known_zeros(
+; CHECK-NEXT:    [[XX:%.*]] = and i32 [[X:%.*]], 360
+; CHECK-NEXT:    ret i32 [[XX]]
+;
+  %xx = and i32 %x, 360
+  %m = mul nuw i32 %xx, 65537
+  %t = lshr i32 %m, 16
+  ret i32 %t
+}
+
 ; Vector type, extra use, weird types are all ok.
 
 define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {
 ; CHECK-LABEL: @mul_splat_fold_vec(
 ; CHECK-NEXT:    [[M:%.*]] = mul nuw <3 x i14> [[X:%.*]], <i14 129, i14 129, i14 129>
 ; CHECK-NEXT:    call void @usevec(<3 x i14> [[M]])
-; CHECK-NEXT:    [[T:%.*]] = and <3 x i14> [[X]], <i14 127, i14 127, i14 127>
-; CHECK-NEXT:    ret <3 x i14> [[T]]
+; CHECK-NEXT:    ret <3 x i14> [[X]]
 ;
   %m = mul nuw <3 x i14> %x, <i14 129, i14 129, i14 129>
   call void @usevec(<3 x i14> %m)
@@ -628,12 +637,10 @@ define i32 @mul_splat_fold_wrong_lshr_const(i32 %x) {
   ret i32 %t
 }
 
-; Negative test
-
 define i32 @mul_splat_fold_no_nuw(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_no_nuw(
-; CHECK-NEXT:    [[M:%.*]] = mul nsw i32 [[X:%.*]], 65537
-; CHECK-NEXT:    [[T:%.*]] = lshr i32 [[M]], 16
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 16
+; CHECK-NEXT:    [[T:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[T]]
 ;
   %m = mul nsw i32 %x, 65537
@@ -641,6 +648,19 @@ define i32 @mul_splat_fold_no_nuw(i32 %x) {
   ret i32 %t
 }
 
+; Negative test
+
+define i32 @mul_splat_fold_no_flags(i32 %x) {
+; CHECK-LABEL: @mul_splat_fold_no_flags(
+; CHECK-NEXT:    [[M:%.*]] = mul i32 [[X:%.*]], 65537
+; CHECK-NEXT:    [[T:%.*]] = lshr i32 [[M]], 16
+; CHECK-NEXT:    ret i32 [[T]]
+;
+  %m = mul i32 %x, 65537
+  %t = lshr i32 %m, 16
+  ret i32 %t
+}
+
 ; Negative test (but simplifies before we reach the mul_splat transform)- need more than 2 bits
 
 define i2 @mul_splat_fold_too_narrow(i2 %x) {

@AZero13 AZero13 changed the title Fold X * (2^N + 1) >> N -> X when N is half the bitwidth of X [InstSimplify] Fold X * (2^N + 1) >> N -> X when N is half the bitwidth of X May 21, 2024
@AZero13 AZero13 force-pushed the mul branch 2 times, most recently from 84cd0b4 to 1937c3f Compare May 21, 2024 23:33
@AZero13
Copy link
Contributor Author

AZero13 commented May 27, 2024

@dtcxzyw What do you think about this?

@AZero13 AZero13 closed this May 29, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants