Skip to content

Commit

Permalink
[InstCombine] Handle equality comparison when flooring by constant 2
Browse files Browse the repository at this point in the history
Support `icmp eq` when reducing signed divisions by power of 2 to
arithmetic shift right, as `icmp ugt` may have been canonicalized
into `icmp eq` by the time additions are folded into `ashr`.

Fixes: #73622.

Proof: https://alive2.llvm.org/ce/z/8-eUdb.
  • Loading branch information
antoniofrighetto committed Nov 30, 2023
1 parent e78a45d commit 7d5f79f
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 6 deletions.
22 changes: 16 additions & 6 deletions llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1234,18 +1234,28 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) {
return nullptr;

// Rounding is done by adding -1 if the dividend (X) is negative and has any
// low bits set. The canonical pattern for that is an "ugt" compare with SMIN:
// sext (icmp ugt (X & (DivC - 1)), SMIN)
const APInt *MaskC;
// low bits set. It recognizes two canonical patterns:
// 1. For an 'ugt' cmp with the signed minimum value (SMIN), the
// pattern is: sext (icmp ugt (X & (DivC - 1)), SMIN).
// 2. For an 'eq' cmp, the pattern is:
//    sext (icmp eq (X & (SMIN + 1)), SMIN + 1).
// Note that, by the time we end up here, if possible, ugt has been
// canonicalized into eq.
const APInt *MaskC, *MaskCCmp;
ICmpInst::Predicate Pred;
if (!match(Add.getOperand(1),
m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)),
m_SignMask()))) ||
Pred != ICmpInst::ICMP_UGT)
m_APInt(MaskCCmp)))))
return nullptr;

if ((Pred != ICmpInst::ICMP_UGT || !MaskCCmp->isSignMask()) &&
(Pred != ICmpInst::ICMP_EQ || *MaskCCmp != *MaskC))
return nullptr;

APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits());
if (*MaskC != (SMin | (*DivC - 1)))
bool IsMaskValid = Pred == ICmpInst::ICMP_UGT
? (*MaskC == (SMin | (*DivC - 1)))
: (*DivC == 2 && *MaskC == SMin + 1);
if (!IsMaskValid)
return nullptr;

// (X / DivC) + sext ((X & (SMin | (DivC - 1)) >u SMin) --> X >>s log2(DivC)
Expand Down
64 changes: 64 additions & 0 deletions llvm/test/Transforms/InstCombine/add.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2700,6 +2700,70 @@ define i32 @floor_sdiv(i32 %x) {
ret i32 %r
}

; Positive test for the 'eq' form of the floor-division-by-2 pattern.
; For i8, -127 is SMIN + 1 (sign bit plus lowest bit), so the compare below is
; true exactly when %x is negative and odd. Adding the sign-extended compare
; (0 or -1) to the truncating sdiv is expected to fold to a single ashr by 1.
define i8 @floor_sdiv_by_2(i8 %x) {
; CHECK-LABEL: @floor_sdiv_by_2(
; CHECK-NEXT: [[RV:%.*]] = ashr i8 [[X:%.*]], 1
; CHECK-NEXT: ret i8 [[RV]]
;
%div = sdiv i8 %x, 2
; Mask and compare constant are both SMIN + 1 for i8.
%and = and i8 %x, -127
%icmp = icmp eq i8 %and, -127
%sext = sext i1 %icmp to i8
%rv = add nsw i8 %div, %sext
ret i8 %rv
}

; Negative test: the mask/compare constant is 127 (all bits except the sign
; bit) instead of SMIN + 1 (-127 for i8), so it does not test "negative and
; odd". The expected output is the input sequence left unchanged.
define i8 @floor_sdiv_by_2_wrong_mask(i8 %x) {
; CHECK-LABEL: @floor_sdiv_by_2_wrong_mask(
; CHECK-NEXT: [[DIV:%.*]] = sdiv i8 [[X:%.*]], 2
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 127
; CHECK-NEXT: [[ICMP:%.*]] = icmp eq i8 [[AND]], 127
; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[ICMP]] to i8
; CHECK-NEXT: [[RV:%.*]] = add nsw i8 [[DIV]], [[SEXT]]
; CHECK-NEXT: ret i8 [[RV]]
;
%div = sdiv i8 %x, 2
; 127 lacks the sign bit, so this is not the SMIN + 1 mask.
%and = and i8 %x, 127
%icmp = icmp eq i8 %and, 127
%sext = sext i1 %icmp to i8
%rv = add nsw i8 %div, %sext
ret i8 %rv
}

; Negative test: the divisor is 4 rather than 2. The mask -125 is
; SMIN | (4 - 1) for i8, but with an 'eq' compare it only matches when both
; low bits are set, which is not the "any low bit set" rounding condition.
; The expected output is the input sequence left unchanged.
define i8 @floor_sdiv_by_2_wrong_constant(i8 %x) {
; CHECK-LABEL: @floor_sdiv_by_2_wrong_constant(
; CHECK-NEXT: [[DIV:%.*]] = sdiv i8 [[X:%.*]], 4
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], -125
; CHECK-NEXT: [[ICMP:%.*]] = icmp eq i8 [[AND]], -125
; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[ICMP]] to i8
; CHECK-NEXT: [[RV:%.*]] = add nsw i8 [[DIV]], [[SEXT]]
; CHECK-NEXT: ret i8 [[RV]]
;
; Divisor 4 is a power of two, but the 'eq' form is only valid for 2.
%div = sdiv i8 %x, 4
%and = and i8 %x, -125
%icmp = icmp eq i8 %and, -125
%sext = sext i1 %icmp to i8
%rv = add nsw i8 %div, %sext
ret i8 %rv
}

; Negative test: same divisor, mask and compare as the positive
; floor-by-2 case, but the i1 compare result is zero-extended instead of
; sign-extended, so the value added is 0 or 1 rather than 0 or -1.
; The expected output is the input sequence left unchanged.
define i8 @floor_sdiv_by_2_wrong_cast(i8 %x) {
; CHECK-LABEL: @floor_sdiv_by_2_wrong_cast(
; CHECK-NEXT: [[DIV:%.*]] = sdiv i8 [[X:%.*]], 2
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], -127
; CHECK-NEXT: [[ICMP:%.*]] = icmp eq i8 [[AND]], -127
; CHECK-NEXT: [[SEXT:%.*]] = zext i1 [[ICMP]] to i8
; CHECK-NEXT: [[RV:%.*]] = add nsw i8 [[DIV]], [[SEXT]]
; CHECK-NEXT: ret i8 [[RV]]
;
%div = sdiv i8 %x, 2
%and = and i8 %x, -127
%icmp = icmp eq i8 %and, -127
; zext (not sext) is the deliberate mismatch; the value keeps the name %sext.
%sext = zext i1 %icmp to i8
%rv = add nsw i8 %div, %sext
ret i8 %rv
}

; vectors work too and commute is handled by complexity-based canonicalization

define <2 x i32> @floor_sdiv_vec_commute(<2 x i32> %x) {
Expand Down

0 comments on commit 7d5f79f

Please sign in to comment.