-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DAG][PatternMatch] Add support for matchers with flags; NFC #103060
Conversation
Add support for matching with `SDNodeFlags` i.e `add` with `nuw`. This patch adds helpers for `or disjoint` or `zext nneg` with the same names as we have in IR/PatternMatch api.
@llvm/pr-subscribers-llvm-selectiondag Author: None (goldsteinn) ChangesAdd support for matching with This patch adds helpers for Full diff: https://github.com/llvm/llvm-project/pull/103060.diff 3 Files Affected:
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 96ece1559bc437..adeaf2fabd39e0 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -508,19 +508,28 @@ struct BinaryOpc_match {
unsigned Opcode;
LHS_P LHS;
RHS_P RHS;
-
- BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R)
- : Opcode(Opc), LHS(L), RHS(R) {}
+ std::optional<SDNodeFlags> Flags;
+ BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
+ std::optional<SDNodeFlags> Flgs = std::nullopt)
+ : Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
assert(EO.Size == 2);
- return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
- RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
- (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
- RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
+ if (!((LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
+ RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
+ (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
+ RHS.match(Ctx, N->getOperand(EO.FirstIndex)))))
+ return false;
+
+ if (!Flags.has_value())
+ return true;
+
+ SDNodeFlags TmpFlags = *Flags;
+ TmpFlags.intersectWith(N->getFlags());
+ return TmpFlags == *Flags;
}
return false;
@@ -575,6 +584,19 @@ inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
}
+template <typename LHS, typename RHS>
+inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
+ const RHS &R) {
+ SDNodeFlags Flags{};
+ Flags.setDisjoint(true);
+ return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, Flags);
+}
+
+template <typename LHS, typename RHS>
+inline auto m_AddLike(const LHS &L, const RHS &R) {
+ return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
+}
+
template <typename LHS, typename RHS>
inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
@@ -661,15 +683,24 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
unsigned Opcode;
Opnd_P Opnd;
-
- UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {}
+ std::optional<SDNodeFlags> Flags;
+ UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
+ std::optional<SDNodeFlags> Flgs = std::nullopt)
+ : Opcode(Opc), Opnd(Op), Flags(Flgs) {}
template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
assert(EO.Size == 1);
- return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
+ if (!Opnd.match(Ctx, N->getOperand(EO.FirstIndex)))
+ return false;
+ if (!Flags.has_value())
+ return true;
+
+ SDNodeFlags TmpFlags = *Flags;
+ TmpFlags.intersectWith(N->getFlags());
+ return TmpFlags == *Flags;
}
return false;
@@ -695,6 +726,13 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
}
+template <typename Opnd>
+inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
+ SDNodeFlags Flags{};
+ Flags.setNonNeg(true);
+ return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, Flags);
+}
+
template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
}
@@ -719,6 +757,10 @@ template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
return m_AnyOf(m_SExt(Op), Op);
}
+template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
+ return m_AnyOf(m_SExt(Op), m_NNegZExt(Op));
+}
+
/// Match a aext or identity
/// Allows to peek through optional extensions
template <typename Opnd>
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 2f36c2e86b1c3a..7837a5f12214bb 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -452,6 +452,20 @@ struct SDNodeFlags {
bool hasNoFPExcept() const { return NoFPExcept; }
bool hasUnpredictable() const { return Unpredictable; }
+ bool operator==(const SDNodeFlags &other) const {
+ return NoUnsignedWrap == other.NoUnsignedWrap &&
+ NoSignedWrap == other.NoSignedWrap && Exact == other.Exact &&
+ Disjoint == other.Disjoint && NonNeg == other.NonNeg &&
+ NoNaNs == other.NoNaNs && NoInfs == other.NoInfs &&
+ NoSignedZeros == other.NoSignedZeros &&
+ AllowReciprocal == other.AllowReciprocal &&
+ AllowContract == other.AllowContract &&
+ ApproximateFuncs == other.ApproximateFuncs &&
+ AllowReassociation == other.AllowReassociation &&
+ NoFPExcept == other.NoFPExcept &&
+ Unpredictable == other.Unpredictable;
+ }
+
/// Clear any flags in this flag set that aren't also set in Flags. All
/// flags will be cleared if Flags are undefined.
void intersectWith(const SDNodeFlags Flags) {
diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
index 074247e6e7d184..6db31990968afa 100644
--- a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
+++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
@@ -185,6 +185,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
+ SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);
SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
@@ -192,6 +193,9 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0);
SDValue Or = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
+ SDNodeFlags DisFlags{};
+ DisFlags.setDisjoint(true);
+ SDValue DisOr = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op3, DisFlags);
SDValue SMax = DAG->getNode(ISD::SMAX, DL, Int32VT, Op0, Op1);
SDValue SMin = DAG->getNode(ISD::SMIN, DL, Int32VT, Op1, Op0);
SDValue UMax = DAG->getNode(ISD::UMAX, DL, Int32VT, Op0, Op1);
@@ -205,6 +209,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(Add, m_AddLike(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(
Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
EXPECT_TRUE(
@@ -217,6 +222,12 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value())));
+ EXPECT_FALSE(sd_match(Or, m_DisjointOr(m_Value(), m_Value())));
+
+ EXPECT_TRUE(sd_match(DisOr, m_Or(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(DisOr, m_DisjointOr(m_Value(), m_Value())));
+ EXPECT_FALSE(sd_match(DisOr, m_Add(m_Value(), m_Value())));
+ EXPECT_TRUE(sd_match(DisOr, m_AddLike(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value())));
@@ -241,8 +252,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
+ SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);
SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0);
+ SDNodeFlags NNegFlags{};
+ NNegFlags.setNonNeg(true);
+ SDValue ZExtNNeg =
+ DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op2, NNegFlags);
SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0);
SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1);
@@ -255,6 +271,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
using namespace SDPatternMatch;
EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value())));
EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value())));
+ EXPECT_TRUE(sd_match(SExt, m_SExtLike(m_Value())));
+ ASSERT_TRUE(ZExtNNeg->getFlags().hasNonNeg());
+ EXPECT_FALSE(sd_match(ZExtNNeg, m_SExt(m_Value())));
+ EXPECT_TRUE(sd_match(ZExtNNeg, m_NNegZExt(m_Value())));
+ EXPECT_FALSE(sd_match(ZExt, m_NNegZExt(m_Value())));
+ EXPECT_TRUE(sd_match(ZExtNNeg, m_SExtLike(m_Value())));
+ EXPECT_FALSE(sd_match(ZExt, m_SExtLike(m_Value())));
EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1))));
EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value())));
|
@@ -452,6 +452,20 @@ struct SDNodeFlags { | |||
bool hasNoFPExcept() const { return NoFPExcept; } | |||
bool hasUnpredictable() const { return Unpredictable; } | |||
|
|||
bool operator==(const SDNodeFlags &other) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bool operator==(const SDNodeFlags &other) const { | |
bool operator==(const SDNodeFlags &Other) const { |
@@ -695,6 +726,13 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) { | |||
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op); | |||
} | |||
|
|||
template <typename Opnd> | |||
inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) { | |||
SDNodeFlags Flags{}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: IIUC since SDNodeFlags has a user-provided default ctor, the value initialization here will eventually fallback to default initialization, which means it's the same as SDNodeFlags Flags;
. Maybe we can use the latter instead.
ditto for other occurrences in this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM w/ minor comments. Thanks!
|
||
SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0); | ||
SDNodeFlags NNegFlags{}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you missed this.
|
||
SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1); | ||
SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0); | ||
SDValue Mul = DAG->getNode(ISD::MUL, DL, Int32VT, Add, Sub); | ||
SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1); | ||
SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0); | ||
SDValue Or = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1); | ||
SDNodeFlags DisFlags{}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
Add support for matching with
SDNodeFlags
i.eadd
withnuw
.This patch adds helpers for
or disjoint
orzext nneg
with the samenames as we have in IR/PatternMatch api.