From 70f3863b5f30e856278f399b068a30bc4d5d16c2 Mon Sep 17 00:00:00 2001 From: Noah Goldstein Date: Tue, 13 Aug 2024 21:18:45 +0800 Subject: [PATCH] [DAG][PatternMatch] Add support for matchers with flags; NFC Add support for matching with `SDNodeFlags` i.e `add` with `nuw`. This patch adds helpers for `or disjoint` or `zext nneg` with the same names as we have in IR/PatternMatch api. Closes #103060 --- llvm/include/llvm/CodeGen/SDPatternMatch.h | 62 ++++++++++++++++--- llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 14 +++++ .../CodeGen/SelectionDAGPatternMatchTest.cpp | 25 +++++++- 3 files changed, 90 insertions(+), 11 deletions(-) diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h index b1aa87ca2d3e13..92efff93f60f89 100644 --- a/llvm/include/llvm/CodeGen/SDPatternMatch.h +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -514,19 +514,28 @@ struct BinaryOpc_match { unsigned Opcode; LHS_P LHS; RHS_P RHS; - - BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R) - : Opcode(Opc), LHS(L), RHS(R) {} + std::optional Flags; + BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R, + std::optional Flgs = std::nullopt) + : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {} template bool match(const MatchContext &Ctx, SDValue N) { if (sd_context_match(N, Ctx, m_Opc(Opcode))) { EffectiveOperands EO(N, Ctx); assert(EO.Size == 2); - return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) && - RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) || - (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) && - RHS.match(Ctx, N->getOperand(EO.FirstIndex))); + if (!((LHS.match(Ctx, N->getOperand(EO.FirstIndex)) && + RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) || + (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) && + RHS.match(Ctx, N->getOperand(EO.FirstIndex))))) + return false; + + if (!Flags.has_value()) + return true; + + SDNodeFlags TmpFlags = *Flags; + TmpFlags.intersectWith(N->getFlags()); + return TmpFlags == *Flags; } return false; @@ -581,6 +590,19 @@ inline BinaryOpc_match m_Or(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::OR, L, R); } +template +inline BinaryOpc_match m_DisjointOr(const LHS &L, + const RHS &R) { + SDNodeFlags Flags; + Flags.setDisjoint(true); + return BinaryOpc_match(ISD::OR, L, R, Flags); +} + +template +inline auto m_AddLike(const LHS &L, const RHS &R) { + return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R)); +} + template inline BinaryOpc_match m_Xor(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::XOR, L, R); @@ -667,15 +689,24 @@ inline BinaryOpc_match m_FRem(const LHS &L, const RHS &R) { template struct UnaryOpc_match { unsigned Opcode; Opnd_P Opnd; - - UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {} + std::optional Flags; + UnaryOpc_match(unsigned Opc, const Opnd_P &Op, + std::optional Flgs = std::nullopt) + : Opcode(Opc), Opnd(Op), Flags(Flgs) {} template bool match(const MatchContext &Ctx, SDValue N) { if (sd_context_match(N, Ctx, m_Opc(Opcode))) { EffectiveOperands EO(N, Ctx); assert(EO.Size == 1); - return Opnd.match(Ctx, N->getOperand(EO.FirstIndex)); + if (!Opnd.match(Ctx, N->getOperand(EO.FirstIndex))) + return false; + if (!Flags.has_value()) + return true; + + SDNodeFlags TmpFlags = *Flags; + TmpFlags.intersectWith(N->getFlags()); + return TmpFlags == *Flags; } return false; @@ -701,6 +732,13 @@ template inline UnaryOpc_match m_ZExt(const Opnd &Op) { return UnaryOpc_match(ISD::ZERO_EXTEND, Op); } +template +inline UnaryOpc_match m_NNegZExt(const Opnd &Op) { + SDNodeFlags Flags; + Flags.setNonNeg(true); + return UnaryOpc_match(ISD::ZERO_EXTEND, Op, Flags); +} + template inline auto m_SExt(const Opnd &Op) { return UnaryOpc_match(ISD::SIGN_EXTEND, Op); } @@ -725,6 +763,10 @@ template inline auto m_SExtOrSelf(const Opnd &Op) { return m_AnyOf(m_SExt(Op), Op); } +template inline auto m_SExtLike(const Opnd &Op) { + return m_AnyOf(m_SExt(Op), m_NNegZExt(Op)); +} + /// Match a aext or identity /// Allows to peek through optional extensions template diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index 2f36c2e86b1c3a..88549d9c9a2858 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -452,6 +452,20 @@ struct SDNodeFlags { bool hasNoFPExcept() const { return NoFPExcept; } bool hasUnpredictable() const { return Unpredictable; } + bool operator==(const SDNodeFlags &Other) const { + return NoUnsignedWrap == Other.NoUnsignedWrap && + NoSignedWrap == Other.NoSignedWrap && Exact == Other.Exact && + Disjoint == Other.Disjoint && NonNeg == Other.NonNeg && + NoNaNs == Other.NoNaNs && NoInfs == Other.NoInfs && + NoSignedZeros == Other.NoSignedZeros && + AllowReciprocal == Other.AllowReciprocal && + AllowContract == Other.AllowContract && + ApproximateFuncs == Other.ApproximateFuncs && + AllowReassociation == Other.AllowReassociation && + NoFPExcept == Other.NoFPExcept && + Unpredictable == Other.Unpredictable; + } + /// Clear any flags in this flag set that aren't also set in Flags. All /// flags will be cleared if Flags are undefined. void intersectWith(const SDNodeFlags Flags) { diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp index c04fc5621ab499..e66584b81bba25 100644 --- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -185,6 +185,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT); SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT); + SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT); SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1); SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0); @@ -192,6 +193,9 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1); SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0); SDValue Or = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1); + SDNodeFlags DisFlags; + DisFlags.setDisjoint(true); + SDValue DisOr = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op3, DisFlags); SDValue SMax = DAG->getNode(ISD::SMAX, DL, Int32VT, Op0, Op1); SDValue SMin = DAG->getNode(ISD::SMIN, DL, Int32VT, Op1, Op0); SDValue UMax = DAG->getNode(ISD::UMAX, DL, Int32VT, Op0, Op1); @@ -205,6 +209,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value()))); EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value()))); EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value()))); + EXPECT_TRUE(sd_match(Add, m_AddLike(m_Value(), m_Value()))); EXPECT_TRUE(sd_match( Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add))))); EXPECT_TRUE( @@ -217,6 +222,12 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value()))); EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value()))); EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(Or, m_DisjointOr(m_Value(), m_Value()))); + + EXPECT_TRUE(sd_match(DisOr, m_Or(m_Value(), m_Value()))); + EXPECT_TRUE(sd_match(DisOr, m_DisjointOr(m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(DisOr, m_Add(m_Value(), m_Value()))); + EXPECT_TRUE(sd_match(DisOr, m_AddLike(m_Value(), m_Value()))); EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value()))); EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value()))); @@ -242,9 +253,14 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) { SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT); - SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, FloatVT); + SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, FloatVT); + SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT); SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0); + SDNodeFlags NNegFlags; + NNegFlags.setNonNeg(true); + SDValue ZExtNNeg = + DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op3, NNegFlags); SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0); SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1); @@ -260,6 +276,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) { using namespace SDPatternMatch; EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value()))); EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value()))); + EXPECT_TRUE(sd_match(SExt, m_SExtLike(m_Value()))); + ASSERT_TRUE(ZExtNNeg->getFlags().hasNonNeg()); + EXPECT_FALSE(sd_match(ZExtNNeg, m_SExt(m_Value()))); + EXPECT_TRUE(sd_match(ZExtNNeg, m_NNegZExt(m_Value()))); + EXPECT_FALSE(sd_match(ZExt, m_NNegZExt(m_Value()))); + EXPECT_TRUE(sd_match(ZExtNNeg, m_SExtLike(m_Value()))); + EXPECT_FALSE(sd_match(ZExt, m_SExtLike(m_Value()))); EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1)))); EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value())));