diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 244f03a1bc2b4c..2e64f02edda376 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1267,6 +1267,18 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       match(Op1, m_SpecificIntAllowPoison(BitWidth - 1)))
     return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
 
+  // If both the add and the shift are nuw, then:
+  // ((X << Z) + Y) nuw >>u Z --> X + (Y >>u Z) nuw
+  Value *Y;
+  if (match(Op0, m_OneUse(m_c_NUWAdd(m_NUWShl(m_Value(X), m_Specific(Op1)),
+                                     m_Value(Y))))) {
+    Value *NewLshr = Builder.CreateLShr(Y, Op1, "", I.isExact());
+    auto *NewAdd = BinaryOperator::CreateNUWAdd(NewLshr, X);
+    // cast<> is safe here: Op0 was just matched as an add instruction.
+    NewAdd->setHasNoSignedWrap(cast<BinaryOperator>(Op0)->hasNoSignedWrap());
+    return NewAdd;
+  }
+
   if (match(Op1, m_APInt(C))) {
     unsigned ShAmtC = C->getZExtValue();
     auto *II = dyn_cast<IntrinsicInst>(Op0);
@@ -1283,7 +1295,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       return new ZExtInst(Cmp, Ty);
     }
 
-    Value *X;
     const APInt *C1;
     if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
       if (C1->ult(ShAmtC)) {
@@ -1328,7 +1339,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
     // ((X << C) + Y) >>u C --> (X + (Y >>u C)) & (-1 >>u C)
     // TODO: Consolidate with the more general transform that starts from shl
     // (the shifts are in the opposite order).
-    Value *Y;
     if (match(Op0,
               m_OneUse(m_c_Add(m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))),
                                m_Value(Y))))) {
@@ -1450,9 +1460,24 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
         NewMul->setHasNoSignedWrap(true);
         return NewMul;
       }
+
+      // Special case: lshr (mul nuw (X, 3)), 1 -> add nuw nsw (X, lshr(X, 1))
+      if (ShAmtC == 1 && *MulC == 3) {
+        auto *NewAdd = BinaryOperator::CreateNUWAdd(
+            X,
+            Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact()));
+        NewAdd->setHasNoSignedWrap(true);
+        return NewAdd;
+      }
     }
   }
 
+  // lshr (mul nsw (X, 3)), 1 -> add nsw (X, lshr(X, 1))
+  if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_SpecificInt(3)))) &&
+      ShAmtC == 1)
+    return BinaryOperator::CreateNSWAdd(
+        X, Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact()));
+
   // Try to narrow bswap.
   // In the case where the shift amount equals the bitwidth difference, the
   // shift is eliminated.
@@ -1656,6 +1681,26 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
     if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
       return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
   }
+
+  // Special case: ashr (mul nsw (X, 3)), 1 -> add nsw (X, ashr(X, 1))
+  if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_SpecificInt(3)))) &&
+      ShAmt == 1) {
+    Value *Shift;
+    // cast<> is safe here: Op0 was just matched as a mul instruction.
+    auto *Op0Bin = cast<BinaryOperator>(Op0);
+    if (Op0Bin->hasNoUnsignedWrap())
+      // We can use lshr if the mul is nuw and nsw.
+      Shift =
+          Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact());
+    else
+      Shift =
+          Builder.CreateAShr(X, ConstantInt::get(Ty, 1), "", I.isExact());
+
+    auto *NewAdd = BinaryOperator::CreateNSWAdd(X, Shift);
+    NewAdd->setHasNoUnsignedWrap(Op0Bin->hasNoUnsignedWrap());
+
+    return NewAdd;
+  }
 }
 
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
diff --git a/llvm/test/Transforms/InstCombine/ashr-lshr.ll b/llvm/test/Transforms/InstCombine/ashr-lshr.ll
index 7dd62327521081..25f53074f4e794 100644
--- a/llvm/test/Transforms/InstCombine/ashr-lshr.ll
+++ b/llvm/test/Transforms/InstCombine/ashr-lshr.ll
@@ -607,8 +607,8 @@ define <2 x i8> @ashr_known_pos_exact_vec(<2 x i8> %x, <2 x i8> %y) {
 
 define i32 @ashr_mul_times_3_div_2(i32 %0) {
 ; CHECK-LABEL: @ashr_mul_times_3_div_2(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[TMP0:%.*]], 3
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
 ; CHECK-NEXT:    ret i32 [[ASHR]]
 ;
   %mul = mul nsw nuw i32 %0, 3
@@ -618,8 +618,8 @@ define i32 @ashr_mul_times_3_div_2(i32 %0) {
 
 define i32 @ashr_mul_times_3_div_2_exact(i32 %x) {
 ; CHECK-LABEL: @ashr_mul_times_3_div_2_exact(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[ASHR:%.*]] = ashr exact i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = ashr exact i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[ASHR:%.*]] = add nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[ASHR]]
 ;
   %mul = mul nsw i32 %x, 3
diff --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll
index 11f16fffeb5510..3fe4e9be146c0a 100644
--- a/llvm/test/Transforms/InstCombine/lshr.ll
+++ b/llvm/test/Transforms/InstCombine/lshr.ll
@@ -364,8 +364,8 @@ define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {
 
 define i32 @mul_times_3_div_2(i32 %x) {
 ; CHECK-LABEL: @mul_times_3_div_2(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw nsw i32 [[X:%.*]], 3
-; CHECK-NEXT:    [[RES:%.*]] = lshr i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[X:%.*]], 1
+; CHECK-NEXT:    [[RES:%.*]] = add nuw nsw i32 [[TMP1]], [[X]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %mul = mul nsw nuw i32 %x, 3
@@ -375,9 +375,8 @@ define i32 @mul_times_3_div_2(i32 %x) {
 
 define i32 @shl_add_lshr(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @shl_add_lshr(
-; CHECK-NEXT:    [[SHL:%.*]] = shl nuw i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i32 [[SHL]], [[Z:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = lshr exact i32 [[ADD]], [[Z]]
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr exact i32 [[Y:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = add nuw nsw i32 [[TMP1]], [[X:%.*]]
 ; CHECK-NEXT:    ret i32 [[RES]]
 ;
   %shl = shl nuw i32 %x, %y
@@ -388,8 +387,8 @@ define i32 @lshr_mul_times_3_div_2(i32 %0) {
 
 define i32 @lshr_mul_times_3_div_2(i32 %0) {
 ; CHECK-LABEL: @lshr_mul_times_3_div_2(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[TMP0:%.*]], 3
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
   %mul = mul nuw i32 %0, 3
@@ -451,8 +450,8 @@ define i32 @mul_times_3_div_2_multiuse(i32 %x) {
 
 define i32 @lshr_mul_times_3_div_2_nsw(i32 %0) {
 ; CHECK-LABEL: @lshr_mul_times_3_div_2_nsw(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[TMP0:%.*]], 3
-; CHECK-NEXT:    [[LSHR:%.*]] = lshr i32 [[MUL]], 1
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[LSHR:%.*]] = add nsw i32 [[TMP2]], [[TMP0]]
 ; CHECK-NEXT:    ret i32 [[LSHR]]
 ;
   %mul = mul nsw i32 %0, 3