From f8cd042ed1af7708260ace83a0f26fe5357e89bf Mon Sep 17 00:00:00 2001 From: Marina Taylor Date: Fri, 20 Sep 2024 15:47:42 +0100 Subject: [PATCH] [InstCombine] Fold `(X==Z) ? (Y==Z) : (!(Y==Z) && X==Y) --> X==Y` This corresponds to the canonicalized form of some logic that was seen in Swift-generated code for comparing optional pointers: `(X==Z || Y==Z) ? (X==Z && Y==Z) : X==Y --> X==Y` where `Z` was the constant `0`. https://alive2.llvm.org/ce/z/J_3aa9 --- .../InstCombine/InstCombineInternal.h | 1 + .../InstCombine/InstCombineSelect.cpp | 41 +++++++++++ .../InstCombine/icmp-equality-test.ll | 72 +++++-------------- 3 files changed, 58 insertions(+), 56 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index a051a568bfd62ee..9bd06a0abb5e530 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -734,6 +734,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final Instruction *foldSelectOfBools(SelectInst &SI); Instruction *foldSelectToCmp(SelectInst &SI); Instruction *foldSelectExtConst(SelectInst &Sel); + Instruction *foldSelectEqualityTest(SelectInst &SI); Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *); Instruction *foldSPFofSPF(Instruction *Inner, SelectPatternFlavor SPF1, diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 66f7c4592457c20..3c8034c9c42f89c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1406,6 +1406,44 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, return nullptr; } +/// Fold the following code sequence: +/// \code +/// %XeqZ = icmp eq i64 %X, %Z +/// %YeqZ = icmp eq i64 %Y, %Z +/// %XeqY = icmp eq i64 %X, %Y +/// %not.YeqZ = xor i1 %YeqZ, true +/// %and = select i1 %not.YeqZ, i1 %XeqY, i1 false +/// %equal = select i1 %XeqZ, i1 %YeqZ, i1 %and +/// \code +/// +/// into: +/// %equal = icmp eq i64 %X, %Y +Instruction *InstCombinerImpl::foldSelectEqualityTest(SelectInst &Sel) { + Value *X, *Y, *Z; + Value *XeqY, *XeqZ = Sel.getCondition(), *YeqZ = Sel.getTrueValue(); + + if (!match(XeqZ, m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(X), m_Value(Z)))) + return nullptr; + + if (!match(YeqZ, + m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(Y), m_Specific(Z)))) + std::swap(X, Z); + + if (!match(YeqZ, + m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(Y), m_Specific(Z)))) + return nullptr; + + if (!match(Sel.getFalseValue(), + m_c_LogicalAnd(m_Not(m_Specific(YeqZ)), m_Value(XeqY)))) + return nullptr; + + if (!match(XeqY, + m_c_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(X), m_Specific(Y)))) + return nullptr; + + return replaceInstUsesWith(Sel, XeqY); +} + // See if this is a pattern like: // %old_cmp1 = icmp slt i32 %x, C2 // %old_replacement = select i1 %old_cmp1, i32 %target_low, i32 %target_high @@ -4084,6 +4122,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Instruction *I = foldSelectToCmp(SI)) return I; + if (Instruction *I = foldSelectEqualityTest(SI)) + return I; + // Fold: // (select A && B, T, F) -> (select A, (select B, T, F), F) // (select A || B, T, F) -> (select A, T, (select B, T, F)) diff --git a/llvm/test/Transforms/InstCombine/icmp-equality-test.ll b/llvm/test/Transforms/InstCombine/icmp-equality-test.ll index b146ecb97782f05..c2740ca7fe8aa9d 100644 --- a/llvm/test/Transforms/InstCombine/icmp-equality-test.ll +++ b/llvm/test/Transforms/InstCombine/icmp-equality-test.ll @@ -4,13 +4,8 @@ define i1 @icmp_equality_test(i64 %X, i64 %Y, i64 %Z) { ; CHECK-LABEL: @icmp_equality_test( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQZ:%.*]] = icmp eq i64 [[X:%.*]], [[Z:%.*]] -; CHECK-NEXT: [[YEQZ:%.*]] = icmp eq i64 [[Y:%.*]], [[Z]] -; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X]], [[Y]] -; CHECK-NEXT: [[NOT_YEQZ:%.*]] = xor i1 [[YEQZ]], true -; CHECK-NEXT: [[AND:%.*]] = select i1 [[NOT_YEQZ]], i1 [[XEQY]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQZ]], i1 [[YEQZ]], i1 [[AND]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] ; entry: %XeqZ = icmp eq i64 %X, %Z @@ -25,13 +20,8 @@ entry: define i1 @icmp_equality_test_constant(i42 %X, i42 %Y) { ; CHECK-LABEL: @icmp_equality_test_constant( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQC:%.*]] = icmp eq i42 [[X:%.*]], -42 -; CHECK-NEXT: [[YEQC:%.*]] = icmp eq i42 [[Y:%.*]], -42 -; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i42 [[X]], [[Y]] -; CHECK-NEXT: [[NOT_YEQC:%.*]] = xor i1 [[YEQC]], true -; CHECK-NEXT: [[AND:%.*]] = select i1 [[NOT_YEQC]], i1 [[XEQY]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQC]], i1 [[YEQC]], i1 [[AND]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i42 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] ; entry: %XeqC = icmp eq i42 %X, -42 @@ -46,13 +36,8 @@ entry: define i1 @icmp_equality_test_swift_optional_pointers(i64 %X, i64 %Y) { ; CHECK-LABEL: @icmp_equality_test_swift_optional_pointers( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQC:%.*]] = icmp eq i64 [[X:%.*]], 0 -; CHECK-NEXT: [[YEQC:%.*]] = icmp eq i64 [[Y:%.*]], 0 -; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X]], [[Y]] -; CHECK-NEXT: [[NOT_YEQC:%.*]] = xor i1 [[YEQC]], true -; CHECK-NEXT: [[BOTH:%.*]] = select i1 [[NOT_YEQC]], i1 [[XEQY]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQC]], i1 [[YEQC]], i1 [[BOTH]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] ; entry: %XeqC = icmp eq i64 %X, 0 @@ -67,13 +52,8 @@ entry: define <2 x i1> @icmp_equality_test_vector(<2 x i64> %X, <2 x i64> %Y) { ; CHECK-LABEL: @icmp_equality_test_vector( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQC:%.*]] = icmp eq <2 x i64> [[X:%.*]], -; CHECK-NEXT: [[YEQC:%.*]] = icmp eq <2 x i64> [[Y:%.*]], -; CHECK-NEXT: [[XEQY:%.*]] = icmp eq <2 x i64> [[X]], [[Y]] -; CHECK-NEXT: [[NOT_YEQC:%.*]] = xor <2 x i1> [[YEQC]], -; CHECK-NEXT: [[AND:%.*]] = select <2 x i1> [[NOT_YEQC]], <2 x i1> [[XEQY]], <2 x i1> zeroinitializer -; CHECK-NEXT: [[EQUAL:%.*]] = select <2 x i1> [[XEQC]], <2 x i1> [[YEQC]], <2 x i1> [[AND]] -; CHECK-NEXT: ret <2 x i1> [[EQUAL]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq <2 x i64> [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret <2 x i1> [[XEQY]] ; entry: %XeqC = icmp eq <2 x i64> %X, @@ -88,13 +68,8 @@ entry: define i1 @icmp_equality_test_commute_icmp1(i64 %X, i64 %Y, i64 %Z) { ; CHECK-LABEL: @icmp_equality_test_commute_icmp1( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQZ:%.*]] = icmp eq i64 [[Z:%.*]], [[X:%.*]] -; CHECK-NEXT: [[YEQZ:%.*]] = icmp eq i64 [[Z]], [[Y:%.*]] -; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[Y]], [[X]] -; CHECK-NEXT: [[NOT_YEQZ:%.*]] = xor i1 [[YEQZ]], true -; CHECK-NEXT: [[AND:%.*]] = select i1 [[NOT_YEQZ]], i1 [[XEQY]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQZ]], i1 [[YEQZ]], i1 [[AND]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] ; entry: %XeqZ = icmp eq i64 %Z, %X @@ -108,13 +83,8 @@ entry: define i1 @icmp_equality_test_commute_icmp2(i64 %X, i64 %Y, i64 %Z) { ; CHECK-LABEL: @icmp_equality_test_commute_icmp2( -; CHECK-NEXT: [[XEQZ:%.*]] = icmp eq i64 [[Z:%.*]], [[X:%.*]] -; CHECK-NEXT: [[YEQZ:%.*]] = icmp eq i64 [[Y:%.*]], [[Z]] -; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[Y]], [[X]] -; CHECK-NEXT: [[NOT_YEQZ:%.*]] = xor i1 [[YEQZ]], true -; CHECK-NEXT: [[AND:%.*]] = select i1 [[NOT_YEQZ]], i1 [[XEQY]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQZ]], i1 [[YEQZ]], i1 [[AND]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] ; %XeqZ = icmp eq i64 %Z, %X %YeqZ = icmp eq i64 %Y, %Z @@ -128,13 +98,8 @@ define i1 @icmp_equality_test_commute_icmp2(i64 %X, i64 %Y, i64 %Z) { define i1 @icmp_equality_test_commute_select1(i64 %X, i64 %Y, i64 %Z) { ; CHECK-LABEL: @icmp_equality_test_commute_select1( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQZ:%.*]] = icmp eq i64 [[X:%.*]], [[Z:%.*]] -; CHECK-NEXT: [[YEQZ:%.*]] = icmp eq i64 [[Y:%.*]], [[Z]] -; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X]], [[Y]] -; CHECK-NEXT: [[NOT_YEQZ:%.*]] = xor i1 [[YEQZ]], true -; CHECK-NEXT: [[AND:%.*]] = select i1 [[NOT_YEQZ]], i1 [[XEQY]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQZ]], i1 [[YEQZ]], i1 [[AND]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] ; entry: %XeqZ = icmp eq i64 %X, %Z @@ -148,13 +113,8 @@ entry: define i1 @icmp_equality_test_commute_select2(i64 %X, i64 %Y, i64 %Z) { ; CHECK-LABEL: @icmp_equality_test_commute_select2( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[XEQZ_NOT:%.*]] = icmp eq i64 [[X:%.*]], [[Z:%.*]] -; CHECK-NEXT: [[YEQZ:%.*]] = icmp eq i64 [[Y:%.*]], [[Z]] -; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X]], [[Y]] -; CHECK-NEXT: [[NOT_YEQZ:%.*]] = xor i1 [[YEQZ]], true -; CHECK-NEXT: [[AND:%.*]] = select i1 [[NOT_YEQZ]], i1 [[XEQY]], i1 false -; CHECK-NEXT: [[EQUAL:%.*]] = select i1 [[XEQZ_NOT]], i1 [[YEQZ]], i1 [[AND]] -; CHECK-NEXT: ret i1 [[EQUAL]] +; CHECK-NEXT: [[XEQY:%.*]] = icmp eq i64 [[X:%.*]], [[Y:%.*]] +; CHECK-NEXT: ret i1 [[XEQY]] ; entry: %XeqZ = icmp eq i64 %X, %Z