Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR][PatternMatch] Make m_Checked{Int,Fp} accept Constant * output instead of APInt * #91377

Closed

Conversation

goldsteinn
Copy link
Contributor

The APInt * version is pretty useless as any case one needs an
APInt * out, they could just replace whatever they have the
m_Checked... lambda with direct checks on the APInt.

Leaving other helpers such as m_Negative, m_Power2,
etc... unchanged as the APInt out version is used mostly for
convenience and rarely change functionality when converted output a
Constant *.

@llvmbot llvmbot added the llvm:ir label May 7, 2024
@llvmbot
Copy link
Member

llvmbot commented May 7, 2024

@llvm/pr-subscribers-llvm-ir

Author: None (goldsteinn)

Changes

The APInt * version is pretty useless as any case one needs an
APInt * out, they could just replace whatever they have the
m_Checked... lambda with direct checks on the APInt.

Leaving other helpers such as m_Negative, m_Power2,
etc... unchanged as the APInt out version is used mostly for
convenience and rarely change functionality when converted output a
Constant *.


Full diff: https://github.com/llvm/llvm-project/pull/91377.diff

2 Files Affected:

  • (modified) llvm/include/llvm/IR/PatternMatch.h (+19-13)
  • (modified) llvm/unittests/IR/PatternMatch.cpp (+55-34)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 5d8f5c134bb5b..171ddab977dea 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -354,7 +354,8 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
 /// is true.
 template <typename Predicate, typename ConstantVal, bool AllowPoison>
 struct cstval_pred_ty : public Predicate {
-  template <typename ITy> bool match(ITy *V) {
+  const Constant **Res = nullptr;
+  template <typename ITy> bool match_impl(ITy *V) {
     if (const auto *CV = dyn_cast<ConstantVal>(V))
       return this->isValue(CV->getValue());
     if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
@@ -387,6 +388,15 @@ struct cstval_pred_ty : public Predicate {
     }
     return false;
   }
+
+  template <typename ITy> bool match(ITy *V) {
+    if (this->match_impl(V)) {
+      if (Res)
+        *Res = cast<Constant>(V);
+      return true;
+    }
+    return false;
+  }
 };
 
 /// specialization of cstval_pred_ty for ConstantInt
@@ -469,28 +479,24 @@ template <typename APTy> struct custom_checkfn {
 /// For vectors, poison elements are assumed to match.
 inline cst_pred_ty<custom_checkfn<APInt>>
 m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
-  return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+  return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}};
 }
 
-inline api_pred_ty<custom_checkfn<APInt>>
-m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
-  api_pred_ty<custom_checkfn<APInt>> P(V);
-  P.CheckFn = CheckFn;
-  return P;
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedInt(const Constant *&V, function_ref<bool(const APInt &)> CheckFn) {
+  return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}, &V};
 }
 
 /// Match a float or vector where CheckFn(ele) for each element is true.
 /// For vectors, poison elements are assumed to match.
 inline cstfp_pred_ty<custom_checkfn<APFloat>>
 m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
-  return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+  return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}};
 }
 
-inline apf_pred_ty<custom_checkfn<APFloat>>
-m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
-  apf_pred_ty<custom_checkfn<APFloat>> P(V);
-  P.CheckFn = CheckFn;
-  return P;
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFp(const Constant *&V, function_ref<bool(const APFloat &)> CheckFn) {
+  return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}, &V};
 }
 
 struct is_any_apint {
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index d5a4a6a05687d..6e79d5cd8ed13 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -614,7 +614,7 @@ TEST_F(PatternMatchTest, BitCast) {
 TEST_F(PatternMatchTest, CheckedInt) {
   Type *I8Ty = IRB.getInt8Ty();
   const APInt *Res = nullptr;
-
+  const Constant * CRes = nullptr;
   auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
   auto CheckTrue = [](const APInt &) { return true; };
   auto CheckFalse = [](const APInt &) { return false; };
@@ -625,39 +625,49 @@ TEST_F(PatternMatchTest, CheckedInt) {
     APInt APVal(8, Val);
     Constant *C = ConstantInt::get(I8Ty, Val);
 
+    CRes = nullptr;
     Res = nullptr;
     EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
-    EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+    EXPECT_TRUE(m_CheckedInt(CRes, CheckTrue).match(C));
+    EXPECT_NE(CRes, nullptr);
+    EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
     EXPECT_EQ(*Res, APVal);
 
+    CRes = nullptr;
     Res = nullptr;
     EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
-    EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+    EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(C));
+    EXPECT_EQ(CRes, nullptr);
 
+    CRes = nullptr;
     Res = nullptr;
     EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
-    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CRes, CheckUgt1).match(C));
     if (CheckUgt1(APVal)) {
-      EXPECT_NE(Res, nullptr);
+      EXPECT_NE(CRes, nullptr);
+      EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
       EXPECT_EQ(*Res, APVal);
     }
 
+    CRes = nullptr;
     Res = nullptr;
     EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
-    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CRes, CheckNonZero).match(C));
     if (CheckNonZero(APVal)) {
-      EXPECT_NE(Res, nullptr);
+      EXPECT_NE(CRes, nullptr);
+      EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
       EXPECT_EQ(*Res, APVal);
     }
 
+    CRes = nullptr;
     Res = nullptr;
     EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
-    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CRes, CheckPow2).match(C));
     if (CheckPow2(APVal)) {
-      EXPECT_NE(Res, nullptr);
+      EXPECT_NE(CRes, nullptr);
+      EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
       EXPECT_EQ(*Res, APVal);
     }
-
   };
 
   DoScalarCheck(0);
@@ -666,20 +676,20 @@ TEST_F(PatternMatchTest, CheckedInt) {
   DoScalarCheck(3);
 
   EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
-  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
-  EXPECT_EQ(Res, nullptr);
+  EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(CRes, nullptr);
 
   EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
-  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
-  EXPECT_EQ(Res, nullptr);
+  EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(CRes, nullptr);
 
   EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
-  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
-  EXPECT_EQ(Res, nullptr);
+  EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(CRes, nullptr);
 
   EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
-  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
-  EXPECT_EQ(Res, nullptr);
+  EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(CRes, nullptr);
 
   auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
                             function_ref<bool(const APInt &)> CheckFn,
@@ -711,13 +721,16 @@ TEST_F(PatternMatchTest, CheckedInt) {
     EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false),
               m_CheckedInt(CheckFn).match(C));
 
+    CRes = nullptr;
     Res = nullptr;
     bool Expec =
-        !(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false);
-    EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+        !(HasUndef && !UndefAsPoison) && Okay.value_or(false);
+    EXPECT_EQ(Expec, m_CheckedInt(CRes, CheckFn).match(C));
     if (Expec) {
-      EXPECT_NE(Res, nullptr);
-      EXPECT_EQ(*Res, *First);
+      EXPECT_NE(CRes, nullptr);
+      EXPECT_EQ(match(CRes, m_APIntAllowPoison(Res)), AllSame);
+      if (AllSame)
+        EXPECT_EQ(*Res, *First);
     }
   };
   auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
@@ -1559,24 +1572,25 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN)));
 
   const APFloat *C;
+  const Constant *CC;
   // Regardless of whether poison is allowed,
   // a fully undef/poison constant does not match.
   EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
   EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C)));
-  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CC, CheckTrue)));
   EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
   EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C)));
-  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CC, CheckTrue)));
   EXPECT_FALSE(match(ScalarPoison, m_APFloat(C)));
   EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C)));
-  EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue)));
+  EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(CC, CheckTrue)));
   EXPECT_FALSE(match(VectorPoison, m_APFloat(C)));
   EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C)));
-  EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue)));
+  EXPECT_FALSE(match(VectorPoison, m_CheckedFp(CC, CheckTrue)));
 
   // We can always match simple constants and simple splats.
   C = nullptr;
@@ -1597,12 +1611,13 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   C = nullptr;
   EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C)));
   EXPECT_TRUE(C->isZero());
-  C = nullptr;
-  EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
-  EXPECT_TRUE(C->isZero());
-  C = nullptr;
-  EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
-  EXPECT_TRUE(C->isZero());
+
+  CC = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckTrue)));
+  EXPECT_TRUE(CC->isNullValue());
+  CC = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckNonNaN)));
+  EXPECT_TRUE(CC->isNullValue());
 
   // Splats with undef are never allowed.
   // Whether splats with poison can be matched depends on the matcher.
@@ -1627,11 +1642,17 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   C = nullptr;
   EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C)));
   EXPECT_TRUE(C->isZero());
+  CC = nullptr;
   C = nullptr;
-  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue)));
+  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckTrue)));
+  EXPECT_NE(CC, nullptr);
+  EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
   EXPECT_TRUE(C->isZero());
+  CC = nullptr;
   C = nullptr;
-  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN)));
+  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckNonNaN)));
+  EXPECT_NE(CC, nullptr);
+  EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
   EXPECT_TRUE(C->isZero());
 }
 

Copy link

github-actions bot commented May 7, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 45fed80b15df85cee53d3d31a7a46ae0daa91a3f 5e4f00d8a82cbaf37e32ebf6d9bfd1058115e05c -- llvm/include/llvm/IR/PatternMatch.h llvm/unittests/IR/PatternMatch.cpp
View the diff from clang-format here.
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 9f91b4f3f9..6a035553f7 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -613,7 +613,7 @@ TEST_F(PatternMatchTest, BitCast) {
 
 TEST_F(PatternMatchTest, CheckedInt) {
   Type *I8Ty = IRB.getInt8Ty();
-  const Constant * CRes = nullptr;
+  const Constant *CRes = nullptr;
   auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
   auto CheckTrue = [](const APInt &) { return true; };
   auto CheckFalse = [](const APInt &) { return false; };

Comment on lines 632 to 633
EXPECT_NE(CRes, nullptr);
EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
EXPECT_NE(CRes, nullptr);
EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
EXPECT_EQ(CRes, C);

I think we should be checking the constant directly, not the APInt it contains.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, how it is now was mostly based on the simplest changes...

…t instead of `APInt *`

The `APInt *` version is pretty useless as any case one needs an
`APInt *` out, they could just replace whatever they have the
`m_Checked...` lambda with direct checks on the `APInt`.

Leaving other helpers such as `m_Negative`, `m_Power2`,
etc... unchanged as the `APInt` out version is used mostly for
convenience and rarely change functionality when converted output a
`Constant *`.
@goldsteinn goldsteinn force-pushed the goldsteinn/ir-fixup-checked-output branch from 98fe785 to 5e4f00d Compare May 8, 2024 16:19
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM w/ formatting fixed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants