Skip to content

Commit

Permalink
[IR][PatternMatch] Make m_Checked{Int,Fp} accept Constant * outpu…
Browse files Browse the repository at this point in the history
…t instead of `APInt *`

The `APInt *` version is pretty useless as any case one needs an
`APInt *` out, they could just replace whatever they have the
`m_Checked...` lambda with direct checks on the `APInt`.

Leaving other helpers such as `m_Negative`, `m_Power2`,
etc... unchanged as the `APInt` out version is used mostly for
convenience and rarely change functionality when converted output a
`Constant *`.

Closes #91377
  • Loading branch information
goldsteinn committed May 11, 2024
1 parent 38b2755 commit 11cb3c3
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 64 deletions.
32 changes: 19 additions & 13 deletions llvm/include/llvm/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
/// is true.
template <typename Predicate, typename ConstantVal, bool AllowPoison>
struct cstval_pred_ty : public Predicate {
template <typename ITy> bool match(ITy *V) {
const Constant **Res = nullptr;
template <typename ITy> bool match_impl(ITy *V) {
if (const auto *CV = dyn_cast<ConstantVal>(V))
return this->isValue(CV->getValue());
if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
Expand Down Expand Up @@ -387,6 +388,15 @@ struct cstval_pred_ty : public Predicate {
}
return false;
}

template <typename ITy> bool match(ITy *V) {
if (this->match_impl(V)) {
if (Res)
*Res = cast<Constant>(V);
return true;
}
return false;
}
};

/// specialization of cstval_pred_ty for ConstantInt
Expand Down Expand Up @@ -469,28 +479,24 @@ template <typename APTy> struct custom_checkfn {
/// For vectors, poison elements are assumed to match.
inline cst_pred_ty<custom_checkfn<APInt>>
m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}};
}

inline api_pred_ty<custom_checkfn<APInt>>
m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
api_pred_ty<custom_checkfn<APInt>> P(V);
P.CheckFn = CheckFn;
return P;
inline cst_pred_ty<custom_checkfn<APInt>>
m_CheckedInt(const Constant *&V, function_ref<bool(const APInt &)> CheckFn) {
return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}, &V};
}

/// Match a float or vector where CheckFn(ele) for each element is true.
/// For vectors, poison elements are assumed to match.
inline cstfp_pred_ty<custom_checkfn<APFloat>>
m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}};
}

inline apf_pred_ty<custom_checkfn<APFloat>>
m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
apf_pred_ty<custom_checkfn<APFloat>> P(V);
P.CheckFn = CheckFn;
return P;
inline cstfp_pred_ty<custom_checkfn<APFloat>>
m_CheckedFp(const Constant *&V, function_ref<bool(const APFloat &)> CheckFn) {
return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}, &V};
}

struct is_any_apint {
Expand Down
104 changes: 53 additions & 51 deletions llvm/unittests/IR/PatternMatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,7 @@ TEST_F(PatternMatchTest, BitCast) {

TEST_F(PatternMatchTest, CheckedInt) {
Type *I8Ty = IRB.getInt8Ty();
const APInt *Res = nullptr;

const Constant * CRes = nullptr;
auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
auto CheckTrue = [](const APInt &) { return true; };
auto CheckFalse = [](const APInt &) { return false; };
Expand All @@ -625,38 +624,33 @@ TEST_F(PatternMatchTest, CheckedInt) {
APInt APVal(8, Val);
Constant *C = ConstantInt::get(I8Ty, Val);

Res = nullptr;
CRes = nullptr;
EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
EXPECT_EQ(*Res, APVal);
EXPECT_TRUE(m_CheckedInt(CRes, CheckTrue).match(C));
EXPECT_EQ(CRes, C);

Res = nullptr;
CRes = nullptr;
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(C));
EXPECT_EQ(CRes, nullptr);

Res = nullptr;
CRes = nullptr;
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
if (CheckUgt1(APVal)) {
EXPECT_NE(Res, nullptr);
EXPECT_EQ(*Res, APVal);
}
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CRes, CheckUgt1).match(C));
if (CheckUgt1(APVal))
EXPECT_EQ(CRes, C);

Res = nullptr;
CRes = nullptr;
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
if (CheckNonZero(APVal)) {
EXPECT_NE(Res, nullptr);
EXPECT_EQ(*Res, APVal);
}
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CRes, CheckNonZero).match(C));
if (CheckNonZero(APVal))
EXPECT_EQ(CRes, C);

Res = nullptr;
CRes = nullptr;
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
if (CheckPow2(APVal)) {
EXPECT_NE(Res, nullptr);
EXPECT_EQ(*Res, APVal);
}
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CRes, CheckPow2).match(C));
if (CheckPow2(APVal))
EXPECT_EQ(CRes, C);

};

Expand All @@ -666,20 +660,20 @@ TEST_F(PatternMatchTest, CheckedInt) {
DoScalarCheck(3);

EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
EXPECT_EQ(Res, nullptr);
EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(UndefValue::get(I8Ty)));
EXPECT_EQ(CRes, nullptr);

EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
EXPECT_EQ(Res, nullptr);
EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(UndefValue::get(I8Ty)));
EXPECT_EQ(CRes, nullptr);

EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
EXPECT_EQ(Res, nullptr);
EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(PoisonValue::get(I8Ty)));
EXPECT_EQ(CRes, nullptr);

EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
EXPECT_EQ(Res, nullptr);
EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(PoisonValue::get(I8Ty)));
EXPECT_EQ(CRes, nullptr);

auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
function_ref<bool(const APInt &)> CheckFn,
Expand Down Expand Up @@ -711,13 +705,13 @@ TEST_F(PatternMatchTest, CheckedInt) {
EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false),
m_CheckedInt(CheckFn).match(C));

Res = nullptr;
bool Expec =
!(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false);
EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
CRes = nullptr;
bool Expec = !(HasUndef && !UndefAsPoison) && Okay.value_or(false);
EXPECT_EQ(Expec, m_CheckedInt(CRes, CheckFn).match(C));
if (Expec) {
EXPECT_NE(Res, nullptr);
EXPECT_EQ(*Res, *First);
EXPECT_NE(CRes, nullptr);
if (AllSame)
EXPECT_EQ(CRes, C);
}
};
auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
Expand Down Expand Up @@ -1559,24 +1553,25 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN)));

const APFloat *C;
const Constant *CC;
// Regardless of whether poison is allowed,
// a fully undef/poison constant does not match.
EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C)));
EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CC, CheckTrue)));
EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C)));
EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue)));
EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CC, CheckTrue)));
EXPECT_FALSE(match(ScalarPoison, m_APFloat(C)));
EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C)));
EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue)));
EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(CC, CheckTrue)));
EXPECT_FALSE(match(VectorPoison, m_APFloat(C)));
EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C)));
EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue)));
EXPECT_FALSE(match(VectorPoison, m_CheckedFp(CC, CheckTrue)));

// We can always match simple constants and simple splats.
C = nullptr;
Expand All @@ -1597,12 +1592,13 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
C = nullptr;
EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C)));
EXPECT_TRUE(C->isZero());
C = nullptr;
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
EXPECT_TRUE(C->isZero());
C = nullptr;
EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
EXPECT_TRUE(C->isZero());

CC = nullptr;
EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckTrue)));
EXPECT_TRUE(CC->isNullValue());
CC = nullptr;
EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckNonNaN)));
EXPECT_TRUE(CC->isNullValue());

// Splats with undef are never allowed.
// Whether splats with poison can be matched depends on the matcher.
Expand All @@ -1627,11 +1623,17 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
C = nullptr;
EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C)));
EXPECT_TRUE(C->isZero());
CC = nullptr;
C = nullptr;
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue)));
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckTrue)));
EXPECT_NE(CC, nullptr);
EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
EXPECT_TRUE(C->isZero());
CC = nullptr;
C = nullptr;
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN)));
EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckNonNaN)));
EXPECT_NE(CC, nullptr);
EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
EXPECT_TRUE(C->isZero());
}

Expand Down

0 comments on commit 11cb3c3

Please sign in to comment.