diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index da6f991ad4cd155..7c6f42de77fc71f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -739,6 +739,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final Instruction *foldSelectOfBools(SelectInst &SI); Instruction *foldSelectToCmp(SelectInst &SI); Instruction *foldSelectExtConst(SelectInst &Sel); + Instruction *foldSelectEqualityTest(SelectInst &SI); Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *); Instruction *foldSPFofSPF(Instruction *Inner, SelectPatternFlavor SPF1, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 3dbe95897d63567..3f780285efe4235 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1410,6 +1410,44 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, return nullptr; } +/// Fold the following code sequence: +/// \code +/// %XeqZ = icmp eq i64 %X, %Z +/// %YeqZ = icmp eq i64 %Y, %Z +/// %XeqY = icmp eq i64 %X, %Y +/// %not.YeqZ = xor i1 %YeqZ, true +/// %and = select i1 %not.YeqZ, i1 %XeqY, i1 false +/// %equal = select i1 %XeqZ, i1 %YeqZ, i1 %and +/// \code +/// +/// into: +/// %equal = icmp eq i64 %X, %Y +Instruction *InstCombinerImpl::foldSelectEqualityTest(SelectInst &Sel) { + Value *X, *Y, *Z; + Value *XeqY, *XeqZ = Sel.getCondition(), *YeqZ = Sel.getTrueValue(); + + if (!match(XeqZ, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(X), m_Value(Z)))) + return nullptr; + + if (!match(YeqZ, + m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(Y), m_Specific(Z)))) + std::swap(X, Z); + + if (!match(YeqZ, + m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(Y), m_Specific(Z)))) + return nullptr; + + if (!match(Sel.getFalseValue(), + m_c_LogicalAnd(m_Not(m_Specific(YeqZ)), m_Value(XeqY)))) + return nullptr; + + if (!match(XeqY, + m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(X), m_Specific(Y)))) + return nullptr; + + return replaceInstUsesWith(Sel, XeqY); +} + // See if this is a pattern like: // %old_cmp1 = icmp slt i32 %x, C2 // %old_replacement = select i1 %old_cmp1, i32 %target_low, i32 %target_high @@ -4112,6 +4150,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Instruction *I = foldSelectToCmp(SI)) return I; + if (Instruction *I = foldSelectEqualityTest(SI)) + return I; + // Fold: // (select A && B, T, F) -> (select A, (select B, T, F), F) // (select A || B, T, F) -> (select A, T, (select B, T, F)) diff --git a/llvm/test/Transforms/InstCombine/icmp-equality-test.ll b/llvm/test/Transforms/InstCombine/icmp-equality-test.ll new file mode 100644 index 000000000000000..c2740ca7fe8aa9d --- /dev/null +++ b/llvm/test/Transforms/InstCombine/icmp-equality-test.ll @@ -0,0 +1,229 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=instcombine -S | FileCheck %s + +define i1 @icmp_equality_test(i64 %X, i64 %Y, i64 %Z) { +; CHECK-LABEL: @icmp_equality_test( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] +; +entry: + %XeqZ = icmp eq i64 %X, %Z + %YeqZ = icmp eq i64 %Y, %Z + %XeqY = icmp eq i64 %X, %Y + %not.YeqZ = xor i1 %YeqZ, true + %and = select i1 %not.YeqZ, i1 %XeqY, i1 false + %equal = select i1 %XeqZ, i1 %YeqZ, i1 %and + ret i1 %equal +} + +define i1 @icmp_equality_test_constant(i42 %X, i42 %Y) { +; CHECK-LABEL: @icmp_equality_test_constant( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i42 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] +; +entry: + %XeqC = icmp eq i42 %X, -42 + %YeqC = icmp eq i42 %Y, -42 + %XeqY = icmp eq i42 %X, %Y + %not.YeqC = xor i1 %YeqC, true + %and = select i1 %not.YeqC, i1 %XeqY, i1 false + %equal = select i1 %XeqC, i1 %YeqC, i1 %and + ret i1 %equal +} + +define i1 @icmp_equality_test_swift_optional_pointers(i64 %X, i64 %Y) { +; CHECK-LABEL: @icmp_equality_test_swift_optional_pointers( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] +; +entry: + %XeqC = icmp eq i64 %X, 0 + %YeqC = icmp eq i64 %Y, 0 + %either = select i1 %XeqC, i1 true, i1 %YeqC + %both = select i1 %XeqC, i1 %YeqC, i1 false + %XeqY = icmp eq i64 %X, %Y + %equal = select i1 %either, i1 %both, i1 %XeqY + ret i1 %equal +} + +define <2 x i1> @icmp_equality_test_vector(<2 x i64> %X, <2 x i64> %Y) { +; CHECK-LABEL: @icmp_equality_test_vector( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq <2 x i64> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret <2 x i1> [[XEQY]] +; +entry: + %XeqC = icmp eq <2 x i64> %X, + %YeqC = icmp eq <2 x i64> %Y, + %XeqY = icmp eq <2 x i64> %X, %Y + %not.YeqC = xor <2 x i1> %YeqC, + %and = select <2 x i1> %not.YeqC, <2 x i1> %XeqY, <2 x i1> + %equal = select <2 x i1> %XeqC, <2 x i1> %YeqC, <2 x i1> %and + ret <2 x i1> %equal +} + +define i1 @icmp_equality_test_commute_icmp1(i64 %X, i64 %Y, i64 %Z) { +; CHECK-LABEL: @icmp_equality_test_commute_icmp1( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] +; +entry: + %XeqZ = icmp eq i64 %Z, %X + %YeqZ = icmp eq i64 %Z, %Y + %XeqY = icmp eq i64 %Y, %X + %not.YeqZ = xor i1 %YeqZ, true + %and = select i1 %not.YeqZ, i1 %XeqY, i1 false + %equal = select i1 %XeqZ, i1 %YeqZ, i1 %and + ret i1 %equal +} + +define i1 @icmp_equality_test_commute_icmp2(i64 %X, i64 %Y, i64 %Z) { +; CHECK-LABEL: @icmp_equality_test_commute_icmp2( +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] +; + %XeqZ = icmp eq i64 %Z, %X + %YeqZ = icmp eq i64 %Y, %Z + %XeqY = icmp eq i64 %Y, %X + %not.YeqZ = xor i1 %YeqZ, true + %and = select i1 %not.YeqZ, i1 %XeqY, i1 false + %equal = select i1 %XeqZ, i1 %YeqZ, i1 %and + ret i1 %equal +} + +define i1 @icmp_equality_test_commute_select1(i64 %X, i64 %Y, i64 %Z) { +; CHECK-LABEL: @icmp_equality_test_commute_select1( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] +; +entry: + %XeqZ = icmp eq i64 %X, %Z + %YeqZ = icmp eq i64 %Y, %Z + %XeqY = icmp eq i64 %X, %Y + %and = select i1 %YeqZ, i1 false, i1 %XeqY + %equal = select i1 %XeqZ, i1 %YeqZ, i1 %and + ret i1 %equal +} + +define i1 @icmp_equality_test_commute_select2(i64 %X, i64 %Y, i64 %Z) { +; CHECK-LABEL: @icmp_equality_test_commute_select2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] +; +entry: + %XeqZ = icmp eq i64 %X, %Z + %YeqZ = icmp eq i64 %Y, %Z + %XeqY = icmp eq i64 %X, %Y + %not.XeqZ = xor i1 %XeqZ, true + %and = select i1 %YeqZ, i1 false, i1 %XeqY + %equal = select i1 %not.XeqZ, i1 %and, i1 %YeqZ + ret i1 %equal +} + +; Negative tests below + +define i1 @icmp_equality_test_wrong_constant(i64 %X, i64 %Y) { +; CHECK-LABEL: @icmp_equality_test_wrong_constant( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQC:%.*]] = icmp eq i64 [[X:%.*]], 0 +; CHECK-NEXT: [[YEQC:%.*]] = icmp eq i64 [[Y:%.*]], 999 +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X]], [[Y]] +; CHECK-NEXT: [[NOT_YEQC:%.*]] = xor i1 [[YEQC]], true +; CHECK-NEXT: [[AND:%.*]] = select i1 [[NOT_YEQC]], i1 [[XEQY]], i1 false +; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQC]], i1 [[YEQC]], i1 [[AND]] +; CHECK-NEXT: ret i1 [[EQUAL]] +; +entry: + %XeqC = icmp eq i64 %X, 0 + %YeqC = icmp eq i64 %Y, 999 + %XeqY = icmp eq i64 %X, %Y + %not.YeqC = xor i1 %YeqC, true + %and = select i1 %not.YeqC, i1 %XeqY, i1 false + %equal = select i1 %XeqC, i1 %YeqC, i1 %and + ret i1 %equal +} + +define i1 @icmp_equality_test_missing_not(i64 %X, i64 %Y, i64 %Z) { +; CHECK-LABEL: @icmp_equality_test_missing_not( +; CHECK-NEXT: [[XEQZ:%.*]] = icmp eq i64 [[X:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[YEQZ:%.*]] = icmp eq i64 [[Y:%.*]], [[Z]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X]], [[Y]] +; CHECK-NEXT: [[AND:%.*]] = select i1 [[YEQZ]], i1 [[XEQY]], i1 false +; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQZ]], i1 [[YEQZ]], i1 [[AND]] +; CHECK-NEXT: ret i1 [[EQUAL]] +; + %XeqZ = icmp eq i64 %X, %Z + %YeqZ = icmp eq i64 %Y, %Z + %XeqY = icmp eq i64 %X, %Y + %and = select i1 %YeqZ, i1 %XeqY, i1 false + %equal = select i1 %XeqZ, i1 %YeqZ, i1 %and + ret i1 %equal +} + +define i1 @icmp_equality_test_wrong_and(i64 %X, i64 %Y, i64 %Z) { +; CHECK-LABEL: @icmp_equality_test_wrong_and( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQZ:%.*]] = icmp eq i64 [[X:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[YEQZ:%.*]] = icmp eq i64 [[Y:%.*]], [[Z]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X]], [[Y]] +; CHECK-NEXT: [[AND:%.*]] = select i1 [[YEQZ]], i1 [[XEQY]], i1 false +; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQZ]], i1 [[YEQZ]], i1 [[AND]] +; CHECK-NEXT: ret i1 [[EQUAL]] +; +entry: + %XeqZ = icmp eq i64 %X, %Z + %YeqZ = icmp eq i64 %Y, %Z + %XeqY = icmp eq i64 %X, %Y + %not.YeqZ = xor i1 %YeqZ, true + %and = select i1 %not.YeqZ, i1 false, i1 %XeqY + %equal = select i1 %XeqZ, i1 %YeqZ, i1 %and + ret i1 %equal +} + +define i1 @icmp_equality_test_wrong_cmp(i64 %X, i64 %Y, i64 %Z) { +; CHECK-LABEL: @icmp_equality_test_wrong_cmp( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQZ:%.*]] = icmp eq i64 [[X:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[YEQZ:%.*]] = icmp eq i64 [[Y:%.*]], [[Z]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X]], 999 +; CHECK-NEXT: [[NOT_YEQZ:%.*]] = xor i1 [[YEQZ]], true +; CHECK-NEXT: [[AND:%.*]] = select i1 [[NOT_YEQZ]], i1 [[XEQY]], i1 false +; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQZ]], i1 [[YEQZ]], i1 [[AND]] +; CHECK-NEXT: ret i1 [[EQUAL]] +; +entry: + %XeqZ = icmp eq i64 %X, %Z + %YeqZ = icmp eq i64 %Y, %Z + %XeqY = icmp eq i64 %X, 999 + %not.YeqZ = xor i1 %YeqZ, true + %and = select i1 %not.YeqZ, i1 %XeqY, i1 false + %equal = select i1 %XeqZ, i1 %YeqZ, i1 %and + ret i1 %equal +} + +define i1 @icmp_equality_test_wrong_equal(i64 %X, i64 %Y, i64 %Z) { +; CHECK-LABEL: @icmp_equality_test_wrong_equal( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[XEQZ:%.*]] = icmp eq i64 [[X:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[YEQZ:%.*]] = icmp eq i64 [[Y:%.*]], [[Z]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X]], [[Y]] +; CHECK-NEXT: [[NOT_YEQZ:%.*]] = xor i1 [[YEQZ]], true +; CHECK-NEXT: [[AND:%.*]] = select i1 [[NOT_YEQZ]], i1 [[XEQY]], i1 false +; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQZ]], i1 [[AND]], i1 [[YEQZ]] +; CHECK-NEXT: ret i1 [[EQUAL]] +; +entry: + %XeqZ = icmp eq i64 %X, %Z + %YeqZ = icmp eq i64 %Y, %Z + %XeqY = icmp eq i64 %X, %Y + %not.YeqZ = xor i1 %YeqZ, true + %and = select i1 %not.YeqZ, i1 %XeqY, i1 false + %equal = select i1 %XeqZ, i1 %and, i1 %YeqZ + ret i1 %equal +}