Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DAG][PatternMatch] Add support for matchers with flags; NFC #103060

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 52 additions & 10 deletions llvm/include/llvm/CodeGen/SDPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,19 +508,28 @@ struct BinaryOpc_match {
unsigned Opcode;
LHS_P LHS;
RHS_P RHS;

BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R)
: Opcode(Opc), LHS(L), RHS(R) {}
std::optional<SDNodeFlags> Flags;
BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R,
std::optional<SDNodeFlags> Flgs = std::nullopt)
: Opcode(Opc), LHS(L), RHS(R), Flags(Flgs) {}

template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
assert(EO.Size == 2);
return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
(Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
RHS.match(Ctx, N->getOperand(EO.FirstIndex)));
if (!((LHS.match(Ctx, N->getOperand(EO.FirstIndex)) &&
RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) ||
(Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) &&
RHS.match(Ctx, N->getOperand(EO.FirstIndex)))))
return false;

if (!Flags.has_value())
return true;

SDNodeFlags TmpFlags = *Flags;
TmpFlags.intersectWith(N->getFlags());
return TmpFlags == *Flags;
}

return false;
Expand Down Expand Up @@ -575,6 +584,19 @@ inline BinaryOpc_match<LHS, RHS, true> m_Or(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R);
}

template <typename LHS, typename RHS>
inline BinaryOpc_match<LHS, RHS, true> m_DisjointOr(const LHS &L,
const RHS &R) {
SDNodeFlags Flags;
Flags.setDisjoint(true);
return BinaryOpc_match<LHS, RHS, true>(ISD::OR, L, R, Flags);
}

template <typename LHS, typename RHS>
inline auto m_AddLike(const LHS &L, const RHS &R) {
return m_AnyOf(m_Add(L, R), m_DisjointOr(L, R));
}

template <typename LHS, typename RHS>
inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
return BinaryOpc_match<LHS, RHS, true>(ISD::XOR, L, R);
Expand Down Expand Up @@ -661,15 +683,24 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match {
unsigned Opcode;
Opnd_P Opnd;

UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {}
std::optional<SDNodeFlags> Flags;
UnaryOpc_match(unsigned Opc, const Opnd_P &Op,
std::optional<SDNodeFlags> Flgs = std::nullopt)
: Opcode(Opc), Opnd(Op), Flags(Flgs) {}

template <typename MatchContext>
bool match(const MatchContext &Ctx, SDValue N) {
if (sd_context_match(N, Ctx, m_Opc(Opcode))) {
EffectiveOperands<ExcludeChain> EO(N);
assert(EO.Size == 1);
return Opnd.match(Ctx, N->getOperand(EO.FirstIndex));
if (!Opnd.match(Ctx, N->getOperand(EO.FirstIndex)))
return false;
if (!Flags.has_value())
return true;

SDNodeFlags TmpFlags = *Flags;
TmpFlags.intersectWith(N->getFlags());
return TmpFlags == *Flags;
}

return false;
Expand All @@ -695,6 +726,13 @@ template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) {
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op);
}

template <typename Opnd>
inline UnaryOpc_match<Opnd> m_NNegZExt(const Opnd &Op) {
SDNodeFlags Flags;
Flags.setNonNeg(true);
return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op, Flags);
}

template <typename Opnd> inline auto m_SExt(const Opnd &Op) {
return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op);
}
Expand All @@ -719,6 +757,10 @@ template <typename Opnd> inline auto m_SExtOrSelf(const Opnd &Op) {
return m_AnyOf(m_SExt(Op), Op);
}

template <typename Opnd> inline auto m_SExtLike(const Opnd &Op) {
return m_AnyOf(m_SExt(Op), m_NNegZExt(Op));
}

/// Match a aext or identity
/// Allows to peek through optional extensions
template <typename Opnd>
Expand Down
14 changes: 14 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,20 @@ struct SDNodeFlags {
bool hasNoFPExcept() const { return NoFPExcept; }
bool hasUnpredictable() const { return Unpredictable; }

bool operator==(const SDNodeFlags &Other) const {
return NoUnsignedWrap == Other.NoUnsignedWrap &&
NoSignedWrap == Other.NoSignedWrap && Exact == Other.Exact &&
Disjoint == Other.Disjoint && NonNeg == Other.NonNeg &&
NoNaNs == Other.NoNaNs && NoInfs == Other.NoInfs &&
NoSignedZeros == Other.NoSignedZeros &&
AllowReciprocal == Other.AllowReciprocal &&
AllowContract == Other.AllowContract &&
ApproximateFuncs == Other.ApproximateFuncs &&
AllowReassociation == Other.AllowReassociation &&
NoFPExcept == Other.NoFPExcept &&
Unpredictable == Other.Unpredictable;
}

/// Clear any flags in this flag set that aren't also set in Flags. All
/// flags will be cleared if Flags are undefined.
void intersectWith(const SDNodeFlags Flags) {
Expand Down
23 changes: 23 additions & 0 deletions llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,13 +185,17 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT);
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT);
SDValue Op3 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 8, Int32VT);

SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1);
SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0);
SDValue Mul = DAG->getNode(ISD::MUL, DL, Int32VT, Add, Sub);
SDValue And = DAG->getNode(ISD::AND, DL, Int32VT, Op0, Op1);
SDValue Xor = DAG->getNode(ISD::XOR, DL, Int32VT, Op1, Op0);
SDValue Or = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op1);
SDNodeFlags DisFlags{};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

DisFlags.setDisjoint(true);
SDValue DisOr = DAG->getNode(ISD::OR, DL, Int32VT, Op0, Op3, DisFlags);
SDValue SMax = DAG->getNode(ISD::SMAX, DL, Int32VT, Op0, Op1);
SDValue SMin = DAG->getNode(ISD::SMIN, DL, Int32VT, Op1, Op0);
SDValue UMax = DAG->getNode(ISD::UMAX, DL, Int32VT, Op0, Op1);
Expand All @@ -205,6 +209,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Add, m_AddLike(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(
Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add)))));
EXPECT_TRUE(
Expand All @@ -217,6 +222,12 @@ TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) {
EXPECT_TRUE(sd_match(Xor, m_Xor(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Or, m_c_BinOp(ISD::OR, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(Or, m_Or(m_Value(), m_Value())));
EXPECT_FALSE(sd_match(Or, m_DisjointOr(m_Value(), m_Value())));

EXPECT_TRUE(sd_match(DisOr, m_Or(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(DisOr, m_DisjointOr(m_Value(), m_Value())));
EXPECT_FALSE(sd_match(DisOr, m_Add(m_Value(), m_Value())));
EXPECT_TRUE(sd_match(DisOr, m_AddLike(m_Value(), m_Value())));

EXPECT_TRUE(sd_match(SMax, m_c_BinOp(ISD::SMAX, m_Value(), m_Value())));
EXPECT_TRUE(sd_match(SMax, m_SMax(m_Value(), m_Value())));
Expand All @@ -241,8 +252,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {

SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT);
SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT);
SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Int32VT);

SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0);
SDNodeFlags NNegFlags{};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you missed this.

NNegFlags.setNonNeg(true);
SDValue ZExtNNeg =
DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op2, NNegFlags);
SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0);
SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1);

Expand All @@ -255,6 +271,13 @@ TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) {
using namespace SDPatternMatch;
EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value())));
EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value())));
EXPECT_TRUE(sd_match(SExt, m_SExtLike(m_Value())));
ASSERT_TRUE(ZExtNNeg->getFlags().hasNonNeg());
EXPECT_FALSE(sd_match(ZExtNNeg, m_SExt(m_Value())));
EXPECT_TRUE(sd_match(ZExtNNeg, m_NNegZExt(m_Value())));
EXPECT_FALSE(sd_match(ZExt, m_NNegZExt(m_Value())));
EXPECT_TRUE(sd_match(ZExtNNeg, m_SExtLike(m_Value())));
EXPECT_FALSE(sd_match(ZExt, m_SExtLike(m_Value())));
EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1))));

EXPECT_TRUE(sd_match(Neg, m_Neg(m_Value())));
Expand Down
Loading