Skip to content

Commit

Permalink
[InstCombine] Fold adds + shifts with nsw and nuw flags
Browse files Browse the repository at this point in the history
I also added mul nsw/nuw 3, div 2 since this was the canonical version of ((x << 1) + x) / 2, which is a specific expression which canonicalization causes the InstCombine to miss it.

Proofs:
https://alive2.llvm.org/ce/z/kDVTiL
https://alive2.llvm.org/ce/z/wORNYm
  • Loading branch information
AZero13 committed Apr 21, 2024
1 parent 0bbe916 commit f15664c
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 15 deletions.
49 changes: 47 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,18 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
match(Op1, m_SpecificIntAllowPoison(BitWidth - 1)))
return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);

// If both the add and the shift are nuw, then:
// ((X << Y) + Z) nuw >>u Z --> X + (Y nuw >>u Z) nuw
Value *Y;
if (match(Op0, m_OneUse(m_c_NUWAdd(m_NUWShl(m_Value(X), m_Value(Y)),
m_Specific(Op1))))) {
Value *NewLshr = Builder.CreateLShr(Y, Op1, "", I.isExact());
auto *newAdd = BinaryOperator::CreateNUWAdd(NewLshr, X);
if (auto *Op0Bin = cast<OverflowingBinaryOperator>(Op0))
newAdd->setHasNoSignedWrap(Op0Bin->hasNoSignedWrap());
return newAdd;
}

if (match(Op1, m_APInt(C))) {
unsigned ShAmtC = C->getZExtValue();
auto *II = dyn_cast<IntrinsicInst>(Op0);
Expand All @@ -1283,7 +1295,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
return new ZExtInst(Cmp, Ty);
}

Value *X;
const APInt *C1;
if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
if (C1->ult(ShAmtC)) {
Expand Down Expand Up @@ -1328,7 +1339,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
// ((X << C) + Y) >>u C --> (X + (Y >>u C)) & (-1 >>u C)
// TODO: Consolidate with the more general transform that starts from shl
// (the shifts are in the opposite order).
Value *Y;
if (match(Op0,
m_OneUse(m_c_Add(m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))),
m_Value(Y))))) {
Expand Down Expand Up @@ -1450,9 +1460,24 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
NewMul->setHasNoSignedWrap(true);
return NewMul;
}

// Special case: lshr nuw (mul (X, 3), 1) -> add nuw nsw (X, lshr(X, 1)
if (ShAmtC == 1 && MulC->getZExtValue() == 3) {
auto *NewAdd = BinaryOperator::CreateNUWAdd(
X,
Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact()));
NewAdd->setHasNoSignedWrap(true);
return NewAdd;
}
}
}

// // lshr nsw (mul (X, 3), 1) -> add nsw (X, lshr(X, 1)
if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_SpecificInt(3)))) &&
ShAmtC == 1)
return BinaryOperator::CreateNSWAdd(
X, Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact()));

// Try to narrow bswap.
// In the case where the shift amount equals the bitwidth difference, the
// shift is eliminated.
Expand Down Expand Up @@ -1656,6 +1681,26 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
}

// Special case: ashr nuw (mul (X, 3), 1) -> add nuw nsw (X, lshr(X, 1)
if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_SpecificInt(3)))) &&
ShAmt == 1) {
Value *Shift;
if (auto *Op0Bin = cast<OverflowingBinaryOperator>(Op0)) {
if (Op0Bin->hasNoUnsignedWrap())
// We can use lshr if the mul is nuw and nsw
Shift =
Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact());
else
Shift =
Builder.CreateAShr(X, ConstantInt::get(Ty, 1), "", I.isExact());

auto *NewAdd = BinaryOperator::CreateNSWAdd(X, Shift);
NewAdd->setHasNoUnsignedWrap(Op0Bin->hasNoUnsignedWrap());

return NewAdd;
}
}
}

const SimplifyQuery Q = SQ.getWithInstruction(&I);
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/Transforms/InstCombine/ashr-lshr.ll
Original file line number Diff line number Diff line change
Expand Up @@ -607,8 +607,8 @@ define <2 x i8> @ashr_known_pos_exact_vec(<2 x i8> %x, <2 x i8> %y) {

define i32 @ashr_mul_times_3_div_2(i32 %0) {
; CHECK-LABEL: @ashr_mul_times_3_div_2(
; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i32 [[TMP0:%.*]], 3
; CHECK-NEXT: [[ASHR:%.*]] = ashr i32 [[MUL]], 1
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
; CHECK-NEXT: [[ASHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
; CHECK-NEXT: ret i32 [[ASHR]]
;
%mul = mul nsw nuw i32 %0, 3
Expand All @@ -618,8 +618,8 @@ define i32 @ashr_mul_times_3_div_2(i32 %0) {

define i32 @ashr_mul_times_3_div_2_exact(i32 %x) {
; CHECK-LABEL: @ashr_mul_times_3_div_2_exact(
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 3
; CHECK-NEXT: [[ASHR:%.*]] = ashr exact i32 [[MUL]], 1
; CHECK-NEXT: [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 1
; CHECK-NEXT: [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
; CHECK-NEXT: ret i32 [[ASHR]]
;
%mul = mul nsw i32 %x, 3
Expand Down
17 changes: 8 additions & 9 deletions llvm/test/Transforms/InstCombine/lshr.ll
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {

define i32 @mul_times_3_div_2(i32 %x) {
; CHECK-LABEL: @mul_times_3_div_2(
; CHECK-NEXT: [[MUL:%.*]] = mul nuw nsw i32 [[X:%.*]], 3
; CHECK-NEXT: [[RES:%.*]] = lshr i32 [[MUL]], 1
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[X:%.*]], 1
; CHECK-NEXT: [[RES:%.*]] = add nuw nsw i32 [[TMP1]], [[X]]
; CHECK-NEXT: ret i32 [[RES]]
;
%mul = mul nsw nuw i32 %x, 3
Expand All @@ -375,9 +375,8 @@ define i32 @mul_times_3_div_2(i32 %x) {

define i32 @shl_add_lshr(i32 %x, i32 %y, i32 %z) {
; CHECK-LABEL: @shl_add_lshr(
; CHECK-NEXT: [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[Y:%.*]]
; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i32 [[SHL]], [[Z:%.*]]
; CHECK-NEXT: [[RES:%.*]] = lshr exact i32 [[ADD]], [[Z]]
; CHECK-NEXT: [[TMP1:%.*]] = lshr exact i32 [[Y:%.*]], [[Z:%.*]]
; CHECK-NEXT: [[RES:%.*]] = add nuw nsw i32 [[TMP1]], [[X:%.*]]
; CHECK-NEXT: ret i32 [[RES]]
;
%shl = shl nuw i32 %x, %y
Expand All @@ -388,8 +387,8 @@ define i32 @shl_add_lshr(i32 %x, i32 %y, i32 %z) {

define i32 @lshr_mul_times_3_div_2(i32 %0) {
; CHECK-LABEL: @lshr_mul_times_3_div_2(
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i32 [[TMP0:%.*]], 3
; CHECK-NEXT: [[LSHR:%.*]] = lshr i32 [[MUL]], 1
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
; CHECK-NEXT: [[LSHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
; CHECK-NEXT: ret i32 [[LSHR]]
;
%mul = mul nuw i32 %0, 3
Expand Down Expand Up @@ -451,8 +450,8 @@ define i32 @mul_times_3_div_2_multiuse(i32 %x) {

define i32 @lshr_mul_times_3_div_2_nsw(i32 %0) {
; CHECK-LABEL: @lshr_mul_times_3_div_2_nsw(
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[TMP0:%.*]], 3
; CHECK-NEXT: [[LSHR:%.*]] = lshr i32 [[MUL]], 1
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
; CHECK-NEXT: [[LSHR:%.*]] = add nsw i32 [[TMP2]], [[TMP0]]
; CHECK-NEXT: ret i32 [[LSHR]]
;
%mul = mul nsw i32 %0, 3
Expand Down

0 comments on commit f15664c

Please sign in to comment.