diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index a051a568bfd62ee..9bd06a0abb5e530 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -734,6 +734,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final Instruction *foldSelectOfBools(SelectInst &SI); Instruction *foldSelectToCmp(SelectInst &SI); Instruction *foldSelectExtConst(SelectInst &Sel); + Instruction *foldSelectEqualityTest(SelectInst &SI); Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *); Instruction *foldSPFofSPF(Instruction *Inner, SelectPatternFlavor SPF1, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 66f7c4592457c20..0598b678f0111e2 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1406,6 +1406,48 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, return nullptr; } +/// Fold the following code sequence: +/// \code +/// %XEq = icmp eq i64 %X, %Z +/// %YEq = icmp eq i64 %Y, %Z +/// %either = select i1 %XEq, i1 true, i1 %YEq +/// %both = select i1 %XEq, i1 %YEq, i1 false +/// %cmp = icmp eq i64 %X, %Y +/// %equal = select i1 %either, i1 %both, i1 %cmp +/// \code +/// +/// into: +/// %equal = icmp eq i64 %X, %Y +/// +/// Equivalently: +/// (X==Z || Y==Z) ? (X==Z && Y==Z) : X==Y --> X==Y +Instruction *InstCombinerImpl::foldSelectEqualityTest(SelectInst &Sel) { + Value *X, *Y, *Z, *XEq, *YEq; + Value *Either = Sel.getCondition(), *Both = Sel.getTrueValue(), + *Cmp = Sel.getFalseValue(); + + if (!match(Either, m_LogicalOr(m_Value(XEq), m_Value(YEq)))) + return nullptr; + + if (!match(XEq, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(X), m_Value(Z)))) + return nullptr; + if (!match(YEq, + m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(Y), m_Specific(Z)))) + std::swap(X, Z); + if (!match(YEq, + m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(Y), m_Specific(Z)))) + return nullptr; + + if (!match(Both, m_c_LogicalAnd(m_Specific(XEq), m_Specific(YEq)))) + return nullptr; + + if (!match(Cmp, + m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(X), m_Specific(Y)))) + return nullptr; + + return replaceInstUsesWith(Sel, Cmp); +} + // See if this is a pattern like: // %old_cmp1 = icmp slt i32 %x, C2 // %old_replacement = select i1 %old_cmp1, i32 %target_low, i32 %target_high @@ -4068,6 +4110,11 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Instruction *I = foldSelectOfSymmetricSelect(SI, Builder)) return I; + // This needs to happen before foldNestedSelects, as that could break the + // patterns that we test for. + if (Instruction *I = foldSelectEqualityTest(SI)) + return I; + if (Instruction *I = foldNestedSelects(SI, Builder)) return I; diff --git a/llvm/test/Transforms/InstCombine/icmp-equality-test.ll b/llvm/test/Transforms/InstCombine/icmp-equality-test.ll index c0d14e96552abb6..f822870273801fc 100644 --- a/llvm/test/Transforms/InstCombine/icmp-equality-test.ll +++ b/llvm/test/Transforms/InstCombine/icmp-equality-test.ll @@ -4,13 +4,8 @@ define i1 @icmp_equality_test(i64 %X, i64 %Y, i64 %Z) { ; CHECK-LABEL: @icmp_equality_test( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQ:%.*]] = icmp eq i64 [[X:%.*]], [[Z:%.*]] -; CHECK-NEXT: [[YEQ:%.*]] = icmp eq i64 [[Y:%.*]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[X]], [[Y]] -; CHECK-NEXT: [[NOT_YEQ:%.*]] = xor i1 [[YEQ]], true -; CHECK-NEXT: [[BOTH:%.*]] = select i1 [[NOT_YEQ]], i1 [[CMP]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQ]], i1 [[YEQ]], i1 [[BOTH]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[CMP]] ; entry: %XEq = icmp eq i64 %X, %Z @@ -25,13 +20,8 @@ entry: define i1 @icmp_equality_test_constant(i42 %X, i42 %Y) { ; CHECK-LABEL: @icmp_equality_test_constant( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQ:%.*]] = icmp eq i42 [[X:%.*]], -42 -; CHECK-NEXT: [[YEQ:%.*]] = icmp eq i42 [[Y:%.*]], -42 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i42 [[X]], [[Y]] -; CHECK-NEXT: [[NOT_YEQ:%.*]] = xor i1 [[YEQ]], true -; CHECK-NEXT: [[BOTH:%.*]] = select i1 [[NOT_YEQ]], i1 [[CMP]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQ]], i1 [[YEQ]], i1 [[BOTH]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i42 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[CMP]] ; entry: %XEq = icmp eq i42 %X, -42 @@ -46,13 +36,8 @@ entry: define <2 x i1> @icmp_equality_test_vector(<2 x i64> %X, <2 x i64> %Y) { ; CHECK-LABEL: @icmp_equality_test_vector( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQ:%.*]] = icmp eq <2 x i64> [[X:%.*]], -; CHECK-NEXT: [[YEQ:%.*]] = icmp eq <2 x i64> [[Y:%.*]], -; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i64> [[X]], [[Y]] -; CHECK-NEXT: [[NOT_YEQ:%.*]] = xor <2 x i1> [[YEQ]], -; CHECK-NEXT: [[BOTH:%.*]] = select <2 x i1> [[NOT_YEQ]], <2 x i1> [[CMP]], <2 x i1> zeroinitializer -; CHECK-NEXT: [[EQUAL:%.*]] = select <2 x i1> [[XEQ]], <2 x i1> [[YEQ]], <2 x i1> [[BOTH]] -; CHECK-NEXT: ret <2 x i1> [[EQUAL]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i64> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret <2 x i1> [[CMP]] ; entry: %XEq = icmp eq <2 x i64> %X, @@ -67,13 +52,8 @@ entry: define i1 @icmp_equality_test_commute_icmp1(i64 %X, i64 %Y, i64 %Z) { ; CHECK-LABEL: @icmp_equality_test_commute_icmp1( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQ:%.*]] = icmp eq i64 [[Z:%.*]], [[X:%.*]] -; CHECK-NEXT: [[YEQ:%.*]] = icmp eq i64 [[Z]], [[Y:%.*]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[Y]], [[X]] -; CHECK-NEXT: [[NOT_YEQ:%.*]] = xor i1 [[YEQ]], true -; CHECK-NEXT: [[BOTH:%.*]] = select i1 [[NOT_YEQ]], i1 [[CMP]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQ]], i1 [[YEQ]], i1 [[BOTH]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: ret i1 [[CMP]] ; entry: %XEq = icmp eq i64 %Z, %X @@ -88,13 +68,8 @@ entry: define i1 @icmp_equality_test_commute_icmp2(i64 %X, i64 %Y, i64 %Z) { ; CHECK-LABEL: @icmp_equality_test_commute_icmp2( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQ:%.*]] = icmp eq i64 [[Z:%.*]], [[X:%.*]] -; CHECK-NEXT: [[YEQ:%.*]] = icmp eq i64 [[Y:%.*]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[Y]], [[X]] -; CHECK-NEXT: [[NOT_YEQ:%.*]] = xor i1 [[YEQ]], true -; CHECK-NEXT: [[BOTH:%.*]] = select i1 [[NOT_YEQ]], i1 [[CMP]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQ]], i1 [[YEQ]], i1 [[BOTH]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: ret i1 [[CMP]] ; entry: %XEq = icmp eq i64 %Z, %X @@ -109,13 +84,8 @@ entry: define i1 @icmp_equality_test_commute_select1(i64 %X, i64 %Y) { ; CHECK-LABEL: @icmp_equality_test_commute_select1( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQ:%.*]] = icmp eq i64 [[X:%.*]], 0 -; CHECK-NEXT: [[YEQ:%.*]] = icmp eq i64 [[Y:%.*]], 0 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[X]], [[Y]] -; CHECK-NEXT: [[NOT_YEQ:%.*]] = xor i1 [[YEQ]], true -; CHECK-NEXT: [[BOTH:%.*]] = select i1 [[NOT_YEQ]], i1 [[CMP]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQ]], i1 [[YEQ]], i1 [[BOTH]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[CMP]] ; entry: %XEq = icmp eq i64 %X, 0 @@ -130,13 +100,8 @@ entry: define i1 @icmp_equality_test_commute_select2(i64 %X, i64 %Y) { ; CHECK-LABEL: @icmp_equality_test_commute_select2( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQ:%.*]] = icmp eq i64 [[X:%.*]], 0 -; CHECK-NEXT: [[YEQ:%.*]] = icmp eq i64 [[Y:%.*]], 0 -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[X]], [[Y]] -; CHECK-NEXT: [[NOT_XEQ:%.*]] = xor i1 [[XEQ]], true -; CHECK-NEXT: [[BOTH:%.*]] = select i1 [[NOT_XEQ]], i1 [[CMP]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[YEQ]], i1 [[XEQ]], i1 [[BOTH]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[CMP]] ; entry: %XEq = icmp eq i64 %X, 0