Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DAG][PatternMatch] Add support for matchers with flags; NFC #103060

Closed
wants to merge 2 commits into from

Conversation

goldsteinn
Copy link
Contributor

Add support for matching with SDNodeFlags i.e add with nuw.

This patch adds helpers for or disjoint or zext nneg with the same
names as we have in IR/PatternMatch api.

Add support for matching with `SDNodeFlags` i.e `add` with `nuw`.

This patch adds helpers for `or disjoint` or `zext nneg` with the same
names as we have in IR/PatternMatch api.
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Aug 13, 2024
@goldsteinn goldsteinn requested review from mshockwave, RKSimon and marcauberer and removed request for RKSimon and mshockwave August 13, 2024 13:20
@llvmbot
Copy link
Collaborator

llvmbot commented Aug 13, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: None (goldsteinn)

Changes

Add support for matching with SDNodeFlags i.e add with nuw.

This patch adds helpers for or disjoint or zext nneg with the same
names as we have in IR/PatternMatch api.


Full diff: https://github.com/llvm/llvm-project/pull/103060.diff

3 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SDPatternMatch.h (+52-10)
  • (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+14)
  • (modified) llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp (+23)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 96ece1559bc437..adeaf2fabd39e0 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -508,19 +508,28 @@ struct BinaryOpc_match {
   unsigned Opcode;
   LHS_P LHS;
   RHS_P RHS;
-
-  BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R)
-      : Opcode(Opc), LHS(L), RHS(R) {}
+  std::optional<SDNodeFlags> Flags;
+  BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
+                  std::optional<SDNodeFlags> Flgs = std::nullopt)
+      : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
       EffectiveOperands<ExcludeChain> EO(N);
       assert(EO.Size == 2);
-      return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
-              RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
-             (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
-              RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
+      if (!((LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
+             RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
+            (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
+             RHS.match(Ctx, N->getOperand(EO.FirstIndex)))))
+        return false;
+
+      if (!Flags.has_value())
+        return true;
+
+      SDNodeFlags TmpFlags = *Flags;
+      TmpFlags.intersectWith(N->getFlags());
+      return TmpFlags == *Flags;
     }
 
     return false;
@@ -575,6 +584,19 @@ inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
   return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
 }
 
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
+                                                    const RHS &R) {
+  SDNodeFlags Flags{};
+  Flags.setDisjoint(true);
+  return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, Flags);
+}
+
+template <typename LHS, typename RHS>
+inline auto m_AddLike(const LHS &L, const RHS &R) {
+  return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
+}
+
 template <typename LHS, typename RHS>
 inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
   return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
@@ -661,15 +683,24 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
 template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
   unsigned Opcode;
   Opnd_P Opnd;
-
-  UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {}
+  std::optional<SDNodeFlags> Flags;
+  UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
+                 std::optional<SDNodeFlags> Flgs = std::nullopt)
+      : Opcode(Opc), Opnd(Op), Flags(Flgs) {}
 
   template <typename MatchContext>
   bool match(const MatchContext &Ctx, SDValue N) {
     if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
       EffectiveOperands<ExcludeChain> EO(N);
       assert(EO.Size == 1);
-      return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
+      if (!Opnd.match(Ctx, N->getOperand(EO.FirstIndex)))
+        return false;
+      if (!Flags.has_value())
+        return true;
+
+      SDNodeFlags TmpFlags = *Flags;
+      TmpFlags.intersectWith(N->getFlags());
+      return TmpFlags == *Flags;
     }
 
     return false;
@@ -695,6 +726,13 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
   return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
 }
 
+template <typename Opnd>
+inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
+  SDNodeFlags Flags{};
+  Flags.setNonNeg(true);
+  return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, Flags);
+}
+
 template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
   return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
 }
@@ -719,6 +757,10 @@ template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
   return m_AnyOf(m_SExt(Op), Op);
 }
 
+template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
+  return m_AnyOf(m_SExt(Op), m_NNegZExt(Op));
+}
+
 /// Match a aext or identity
 /// Allows to peek through optional extensions
 template <typename Opnd>
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 2f36c2e86b1c3a..7837a5f12214bb 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -452,6 +452,20 @@ struct SDNodeFlags {
   bool hasNoFPExcept() const { return NoFPExcept; }
   bool hasUnpredictable() const { return Unpredictable; }
 
+  bool operator==(const SDNodeFlags &other) const {
+    return NoUnsignedWrap == other.NoUnsignedWrap &&
+           NoSignedWrap == other.NoSignedWrap && Exact == other.Exact &&
+           Disjoint == other.Disjoint && NonNeg == other.NonNeg &&
+           NoNaNs == other.NoNaNs && NoInfs == other.NoInfs &&
+           NoSignedZeros == other.NoSignedZeros &&
+           AllowReciprocal == other.AllowReciprocal &&
+           AllowContract == other.AllowContract &&
+           ApproximateFuncs == other.ApproximateFuncs &&
+           AllowReassociation == other.AllowReassociation &&
+           NoFPExcept == other.NoFPExcept &&
+           Unpredictable == other.Unpredictable;
+  }
+
   /// Clear any flags in this flag set that aren't also set in Flags. All
   /// flags will be cleared if Flags are undefined.
   void intersectWith(const SDNodeFlags Flags) {
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 074247e6e7d184..6db31990968afa 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -185,6 +185,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
   SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
+  SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
 
   SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
   SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
@@ -192,6 +193,9 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
   SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0);
   SDValue Or  = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
+  SDNodeFlags DisFlags{};
+  DisFlags.setDisjoint(true);
+  SDValue DisOr = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op3, DisFlags);
   SDValue SMax = DAG->getNode(ISD::SMAX, DL, Int32VT, Op0, Op1);
   SDValue SMin = DAG->getNode(ISD::SMIN, DL, Int32VT, Op1, Op0);
   SDValue UMax = DAG->getNode(ISD::UMAX, DL, Int32VT, Op0, Op1);
@@ -205,6 +209,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(Add, m_AddLike(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(
       Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
   EXPECT_TRUE(
@@ -217,6 +222,12 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
   EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(Or, m_DisjointOr(m_Value(), m_Value())));
+
+  EXPECT_TRUE(sd_match(DisOr, m_Or(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(DisOr, m_DisjointOr(m_Value(), m_Value())));
+  EXPECT_FALSE(sd_match(DisOr, m_Add(m_Value(), m_Value())));
+  EXPECT_TRUE(sd_match(DisOr, m_AddLike(m_Value(), m_Value())));
 
   EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value())));
   EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value())));
@@ -241,8 +252,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
 
   SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
   SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
+  SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
 
   SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0);
+  SDNodeFlags NNegFlags{};
+  NNegFlags.setNonNeg(true);
+  SDValue ZExtNNeg =
+      DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op2, NNegFlags);
   SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0);
   SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1);
 
@@ -255,6 +271,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
   using namespace SDPatternMatch;
   EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value())));
   EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value())));
+  EXPECT_TRUE(sd_match(SExt, m_SExtLike(m_Value())));
+  ASSERT_TRUE(ZExtNNeg->getFlags().hasNonNeg());
+  EXPECT_FALSE(sd_match(ZExtNNeg, m_SExt(m_Value())));
+  EXPECT_TRUE(sd_match(ZExtNNeg, m_NNegZExt(m_Value())));
+  EXPECT_FALSE(sd_match(ZExt, m_NNegZExt(m_Value())));
+  EXPECT_TRUE(sd_match(ZExtNNeg, m_SExtLike(m_Value())));
+  EXPECT_FALSE(sd_match(ZExt, m_SExtLike(m_Value())));
   EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1))));
 
   EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value())));

@@ -452,6 +452,20 @@ struct SDNodeFlags {
bool hasNoFPExcept() const { return NoFPExcept; }
bool hasUnpredictable() const { return Unpredictable; }

bool operator==(const SDNodeFlags &other) const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
bool operator==(const SDNodeFlags &other) const {
bool operator==(const SDNodeFlags &Other) const {

@@ -695,6 +726,13 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
}

template <typename Opnd>
inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
SDNodeFlags Flags{};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: IIUC since SDNodeFlags has a user-provided default ctor, the value initialization here will eventually fallback to default initialization, which means it's the same as SDNodeFlags Flags;. Maybe we can use the latter instead.

ditto for other occurrences in this PR.

Copy link
Member

@marcauberer marcauberer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

Copy link
Member

@mshockwave mshockwave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM w/ minor comments. Thanks!


SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0);
SDNodeFlags NNegFlags{};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you missed this.


SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
SDValue Mul = DAG->getNode(ISD::MUL, DL, Int32VT, Add, Sub);
SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0);
SDValue Or = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
SDNodeFlags DisFlags{};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants