-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR][PatternMatch] Make m_Checked{Int,Fp}
accept Constant *
output instead of APInt *
#91377
[IR][PatternMatch] Make m_Checked{Int,Fp}
accept Constant *
output instead of APInt *
#91377
Conversation
@llvm/pr-subscribers-llvm-ir Author: None (goldsteinn) ChangesThe Leaving other helpers such as Full diff: https://github.com/llvm/llvm-project/pull/91377.diff 2 Files Affected:
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 5d8f5c134bb5b..171ddab977dea 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -354,7 +354,8 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
/// is true.
template <typename Predicate, typename ConstantVal, bool AllowPoison>
struct cstval_pred_ty : public Predicate {
- template <typename ITy> bool match(ITy *V) {
+ const Constant **Res = nullptr;
+ template <typename ITy> bool match_impl(ITy *V) {
if (const auto *CV = dyn_cast<ConstantVal>(V))
return this->isValue(CV->getValue());
if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
@@ -387,6 +388,15 @@ struct cstval_pred_ty : public Predicate {
}
return false;
}
+
+ template <typename ITy> bool match(ITy *V) {
+ if (this->match_impl(V)) {
+ if (Res)
+ *Res = cast<Constant>(V);
+ return true;
+ }
+ return false;
+ }
};
/// specialization of cstval_pred_ty for ConstantInt
@@ -469,28 +479,24 @@ template <typename APTy> struct custom_checkfn {
/// For vectors, poison elements are assumed to match.
inline cst_pred_ty<custom_checkfn<APInt>>
m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
- return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+ return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}};
}
-inline api_pred_ty<custom_checkfn<APInt>>
-m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
- api_pred_ty<custom_checkfn<APInt>> P(V);
- P.CheckFn = CheckFn;
- return P;
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedInt(const Constant *&V, function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}, &V};
}
/// Match a float or vector where CheckFn(ele) for each element is true.
/// For vectors, poison elements are assumed to match.
inline cstfp_pred_ty<custom_checkfn<APFloat>>
m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
- return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+ return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}};
}
-inline apf_pred_ty<custom_checkfn<APFloat>>
-m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
- apf_pred_ty<custom_checkfn<APFloat>> P(V);
- P.CheckFn = CheckFn;
- return P;
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFp(const Constant *&V, function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}, &V};
}
struct is_any_apint {
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index d5a4a6a05687d..6e79d5cd8ed13 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -614,7 +614,7 @@ TEST_F(PatternMatchTest, BitCast) {
TEST_F(PatternMatchTest, CheckedInt) {
Type *I8Ty = IRB.getInt8Ty();
const APInt *Res = nullptr;
-
+ const Constant * CRes = nullptr;
auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
auto CheckTrue = [](const APInt &) { return true; };
auto CheckFalse = [](const APInt &) { return false; };
@@ -625,39 +625,49 @@ TEST_F(PatternMatchTest, CheckedInt) {
APInt APVal(8, Val);
Constant *C = ConstantInt::get(I8Ty, Val);
+ CRes = nullptr;
Res = nullptr;
EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
- EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+ EXPECT_TRUE(m_CheckedInt(CRes, CheckTrue).match(C));
+ EXPECT_NE(CRes, nullptr);
+ EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
EXPECT_EQ(*Res, APVal);
+ CRes = nullptr;
Res = nullptr;
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
- EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+ EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(C));
+ EXPECT_EQ(CRes, nullptr);
+ CRes = nullptr;
Res = nullptr;
EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
- EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CRes, CheckUgt1).match(C));
if (CheckUgt1(APVal)) {
- EXPECT_NE(Res, nullptr);
+ EXPECT_NE(CRes, nullptr);
+ EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
EXPECT_EQ(*Res, APVal);
}
+ CRes = nullptr;
Res = nullptr;
EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
- EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CRes, CheckNonZero).match(C));
if (CheckNonZero(APVal)) {
- EXPECT_NE(Res, nullptr);
+ EXPECT_NE(CRes, nullptr);
+ EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
EXPECT_EQ(*Res, APVal);
}
+ CRes = nullptr;
Res = nullptr;
EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
- EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CRes, CheckPow2).match(C));
if (CheckPow2(APVal)) {
- EXPECT_NE(Res, nullptr);
+ EXPECT_NE(CRes, nullptr);
+ EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
EXPECT_EQ(*Res, APVal);
}
-
};
DoScalarCheck(0);
@@ -666,20 +676,20 @@ TEST_F(PatternMatchTest, CheckedInt) {
DoScalarCheck(3);
EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
- EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
- EXPECT_EQ(Res, nullptr);
+ EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(CRes, nullptr);
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
- EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
- EXPECT_EQ(Res, nullptr);
+ EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(CRes, nullptr);
EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
- EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
- EXPECT_EQ(Res, nullptr);
+ EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(CRes, nullptr);
EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
- EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
- EXPECT_EQ(Res, nullptr);
+ EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(CRes, nullptr);
auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
function_ref<bool(const APInt &)> CheckFn,
@@ -711,13 +721,16 @@ TEST_F(PatternMatchTest, CheckedInt) {
EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false),
m_CheckedInt(CheckFn).match(C));
+ CRes = nullptr;
Res = nullptr;
bool Expec =
- !(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false);
- EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+ !(HasUndef && !UndefAsPoison) && Okay.value_or(false);
+ EXPECT_EQ(Expec, m_CheckedInt(CRes, CheckFn).match(C));
if (Expec) {
- EXPECT_NE(Res, nullptr);
- EXPECT_EQ(*Res, *First);
+ EXPECT_NE(CRes, nullptr);
+ EXPECT_EQ(match(CRes, m_APIntAllowPoison(Res)), AllSame);
+ if (AllSame)
+ EXPECT_EQ(*Res, *First);
}
};
auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
@@ -1559,24 +1572,25 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN)));
const APFloat *C;
+ const Constant *CC;
// Regardless of whether poison is allowed,
// a fully undef/poison constant does not match.
EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C)));
- EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CC, CheckTrue)));
EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C)));
- EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CC, CheckTrue)));
EXPECT_FALSE(match(ScalarPoison, m_APFloat(C)));
EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C)));
- EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue)));
+ EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(CC, CheckTrue)));
EXPECT_FALSE(match(VectorPoison, m_APFloat(C)));
EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C)));
EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C)));
- EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue)));
+ EXPECT_FALSE(match(VectorPoison, m_CheckedFp(CC, CheckTrue)));
// We can always match simple constants and simple splats.
C = nullptr;
@@ -1597,12 +1611,13 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
C = nullptr;
EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C)));
EXPECT_TRUE(C->isZero());
- C = nullptr;
- EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
- EXPECT_TRUE(C->isZero());
- C = nullptr;
- EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
- EXPECT_TRUE(C->isZero());
+
+ CC = nullptr;
+ EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckTrue)));
+ EXPECT_TRUE(CC->isNullValue());
+ CC = nullptr;
+ EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckNonNaN)));
+ EXPECT_TRUE(CC->isNullValue());
// Splats with undef are never allowed.
// Whether splats with poison can be matched depends on the matcher.
@@ -1627,11 +1642,17 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
C = nullptr;
EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C)));
EXPECT_TRUE(C->isZero());
+ CC = nullptr;
C = nullptr;
- EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue)));
+ EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckTrue)));
+ EXPECT_NE(CC, nullptr);
+ EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
EXPECT_TRUE(C->isZero());
+ CC = nullptr;
C = nullptr;
- EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN)));
+ EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckNonNaN)));
+ EXPECT_NE(CC, nullptr);
+ EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
EXPECT_TRUE(C->isZero());
}
|
You can test this locally with the following command:git-clang-format --diff 45fed80b15df85cee53d3d31a7a46ae0daa91a3f 5e4f00d8a82cbaf37e32ebf6d9bfd1058115e05c -- llvm/include/llvm/IR/PatternMatch.h llvm/unittests/IR/PatternMatch.cpp View the diff from clang-format here.diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 9f91b4f3f9..6a035553f7 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -613,7 +613,7 @@ TEST_F(PatternMatchTest, BitCast) {
TEST_F(PatternMatchTest, CheckedInt) {
Type *I8Ty = IRB.getInt8Ty();
- const Constant * CRes = nullptr;
+ const Constant *CRes = nullptr;
auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
auto CheckTrue = [](const APInt &) { return true; };
auto CheckFalse = [](const APInt &) { return false; };
|
llvm/unittests/IR/PatternMatch.cpp
Outdated
EXPECT_NE(CRes, nullptr); | ||
EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
EXPECT_NE(CRes, nullptr); | |
EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res))); | |
EXPECT_EQ(CRes, C); |
I think we should be checking the constant directly, not the APInt it contains.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, how it is now was mostly based on the simplest changes...
…t instead of `APInt *` The `APInt *` version is pretty useless as any case one needs an `APInt *` out, they could just replace whatever they have the `m_Checked...` lambda with direct checks on the `APInt`. Leaving other helpers such as `m_Negative`, `m_Power2`, etc... unchanged as the `APInt` out version is used mostly for convenience and rarely change functionality when converted output a `Constant *`.
98fe785
to
5e4f00d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM w/ formatting fixed.
The
APInt *
version is pretty useless as any case one needs anAPInt *
out, they could just replace whatever they have them_Checked...
lambda with direct checks on theAPInt
.Leaving other helpers such as
m_Negative
,m_Power2
,etc... unchanged as the
APInt
out version is used mostly forconvenience and rarely change functionality when converted output a
Constant *
.