Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Support] Add KnownBits::abds signed absolute difference and rename absdiff -> abdu #84897

Merged
merged 1 commit into from
Mar 12, 2024

Conversation

RKSimon
Copy link
Collaborator

@RKSimon RKSimon commented Mar 12, 2024

When I created KnownBits::absdiff, I totally missed that we already have ISD::ABDS/ABDU nodes, and we use this term in other places/targets as well.

I've added the KnownBits::abds implementation and renamed KnownBits::absdiff to KnownBits::abdu.

Followup to #84791

@llvmbot
Copy link
Member

llvmbot commented Mar 12, 2024

@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-llvm-support

Author: Simon Pilgrim (RKSimon)

Changes

When I created KnownBits::absdiff, I totally missed that we already have ISD::ABDS/ABDU nodes, and we use this term in other places/targets as well.

I've added the KnownBits::abds implementation and renamed KnownBits::absdiff to KnownBits::abdu.

Followup to #84791


Full diff: https://github.com/llvm/llvm-project/pull/84897.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Support/KnownBits.h (+5-2)
  • (modified) llvm/lib/Support/KnownBits.cpp (+21-2)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1)
  • (modified) llvm/unittests/Support/KnownBitsTest.cpp (+38-6)
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index 06d2c90f7b0f6b..73cb01e0644a8d 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -390,8 +390,11 @@ struct KnownBits {
   /// Compute known bits for smin(LHS, RHS).
   static KnownBits smin(const KnownBits &LHS, const KnownBits &RHS);
 
-  /// Compute known bits for absdiff(LHS, RHS).
-  static KnownBits absdiff(const KnownBits &LHS, const KnownBits &RHS);
+  /// Compute known bits for abdu(LHS, RHS).
+  static KnownBits abdu(const KnownBits &LHS, const KnownBits &RHS);
+
+  /// Compute known bits for abds(LHS, RHS).
+  static KnownBits abds(const KnownBits &LHS, const KnownBits &RHS);
 
   /// Compute known bits for shl(LHS, RHS).
   /// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index c33c3680825a10..d72355dab6f1d3 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -231,8 +231,8 @@ KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
   return Flip(umax(Flip(LHS), Flip(RHS)));
 }
 
-KnownBits KnownBits::absdiff(const KnownBits &LHS, const KnownBits &RHS) {
-  // absdiff(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
+KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
+  // abdu(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
   KnownBits UMaxValue = umax(LHS, RHS);
   KnownBits UMinValue = umin(LHS, RHS);
   KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
@@ -250,6 +250,25 @@ KnownBits KnownBits::absdiff(const KnownBits &LHS, const KnownBits &RHS) {
   return KnownAbsDiff;
 }
 
+KnownBits KnownBits::abds(const KnownBits &LHS, const KnownBits &RHS) {
+  // abds(LHS,RHS) = sub(smax(LHS,RHS), smin(LHS,RHS)).
+  KnownBits SMaxValue = smax(LHS, RHS);
+  KnownBits SMinValue = smin(LHS, RHS);
+  KnownBits MinMaxDiff = computeForAddSub(/*Add=*/false, /*NSW=*/false,
+                                          /*NUW=*/false, SMaxValue, SMinValue);
+
+  // find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
+  KnownBits Diff0 =
+      computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
+  KnownBits Diff1 =
+      computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, RHS, LHS);
+  KnownBits SubDiff = Diff0.intersectWith(Diff1);
+
+  KnownBits KnownAbsDiff = MinMaxDiff.unionWith(SubDiff);
+  assert(!KnownAbsDiff.hasConflict() && "Bad Output");
+  return KnownAbsDiff;
+}
+
 static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
   if (isPowerOf2_32(BitWidth))
     return MaxValue.extractBitsAsZExtValue(Log2_32(BitWidth), 0);
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a74901958ac056..b4d0421c14c0b6 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -36753,7 +36753,7 @@ static void computeKnownBitsForPSADBW(SDValue LHS, SDValue RHS,
   APInt DemandedSrcElts = APIntOps::ScaleBitMask(DemandedElts, NumSrcElts);
   Known = DAG.computeKnownBits(RHS, DemandedSrcElts, Depth + 1);
   Known2 = DAG.computeKnownBits(LHS, DemandedSrcElts, Depth + 1);
-  Known = KnownBits::absdiff(Known, Known2).zext(16);
+  Known = KnownBits::abdu(Known, Known2).zext(16);
   // Known = (((D0 + D1) + (D2 + D3)) + ((D4 + D5) + (D6 + D7)))
   Known = KnownBits::computeForAddSub(/*Add=*/true, /*NSW=*/true, /*NUW=*/true,
                                       Known, Known);
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index 2ac25f0b2801ba..939749b931a80a 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -294,18 +294,18 @@ TEST(KnownBitsTest, SignBitUnknown) {
   EXPECT_TRUE(Known.isSignUnknown());
 }
 
-TEST(KnownBitsTest, AbsDiffSpecialCase) {
-  // There are 2 implementation of absdiff - both are currently needed to cover
+TEST(KnownBitsTest, ABDUSpecialCase) {
+  // There are 2 implementation of abdu - both are currently needed to cover
   // extra cases.
   KnownBits LHS, RHS, Res;
 
-  // absdiff(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
+  // abdu(LHS,RHS) = sub(umax(LHS,RHS), umin(LHS,RHS)).
   // Actual: false (Inputs = 1011, 101?, Computed = 000?, Exact = 000?)
   LHS.One = APInt(4, 0b1011);
   RHS.One = APInt(4, 0b1010);
   LHS.Zero = APInt(4, 0b0100);
   RHS.Zero = APInt(4, 0b0100);
-  Res = KnownBits::absdiff(LHS, RHS);
+  Res = KnownBits::abdu(LHS, RHS);
   EXPECT_EQ(0b0000ul, Res.One.getZExtValue());
   EXPECT_EQ(0b1110ul, Res.Zero.getZExtValue());
 
@@ -315,11 +315,37 @@ TEST(KnownBitsTest, AbsDiffSpecialCase) {
   RHS.One = APInt(4, 0b1000);
   LHS.Zero = APInt(4, 0b0000);
   RHS.Zero = APInt(4, 0b0111);
-  Res = KnownBits::absdiff(LHS, RHS);
+  Res = KnownBits::abdu(LHS, RHS);
   EXPECT_EQ(0b0001ul, Res.One.getZExtValue());
   EXPECT_EQ(0b0000ul, Res.Zero.getZExtValue());
 }
 
+TEST(KnownBitsTest, ABDSSpecialCase) {
+  // There are 2 implementation of abds - both are currently needed to cover
+  // extra cases.
+  KnownBits LHS, RHS, Res;
+
+  // abds(LHS,RHS) = sub(smax(LHS,RHS), smin(LHS,RHS)).
+  // Actual: false (Inputs = 1011, 10??, Computed = ????, Exact = 00??)
+  LHS.One = APInt(4, 0b1011);
+  RHS.One = APInt(4, 0b1000);
+  LHS.Zero = APInt(4, 0b0100);
+  RHS.Zero = APInt(4, 0b0100);
+  Res = KnownBits::abds(LHS, RHS);
+  EXPECT_EQ(0, Res.One.getSExtValue());
+  EXPECT_EQ(-4, Res.Zero.getSExtValue());
+
+  // find the common bits between sub(LHS,RHS) and sub(RHS,LHS).
+  // Actual: false (Inputs = ???1, 1000, Computed = ???1, Exact = 0??1)
+  LHS.One = APInt(4, 0b0001);
+  RHS.One = APInt(4, 0b1000);
+  LHS.Zero = APInt(4, 0b0000);
+  RHS.Zero = APInt(4, 0b0111);
+  Res = KnownBits::abds(LHS, RHS);
+  EXPECT_EQ(1, Res.One.getSExtValue());
+  EXPECT_EQ(0, Res.Zero.getSExtValue());
+}
+
 TEST(KnownBitsTest, BinaryExhaustive) {
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
@@ -359,10 +385,16 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const APInt &N1, const APInt &N2) { return APIntOps::smin(N1, N2); });
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
-        return KnownBits::absdiff(Known1, Known2);
+        return KnownBits::abdu(Known1, Known2);
       },
       [](const APInt &N1, const APInt &N2) { return APIntOps::abdu(N1, N2); },
       checkCorrectnessOnlyBinary);
+  testBinaryOpExhaustive(
+      [](const KnownBits &Known1, const KnownBits &Known2) {
+        return KnownBits::abds(Known1, Known2);
+      },
+      [](const APInt &N1, const APInt &N2) { return APIntOps::abds(N1, N2); },
+      checkCorrectnessOnlyBinary);
   testBinaryOpExhaustive(
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::udiv(Known1, Known2);

Copy link
Contributor

@jayfoad jayfoad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

…bsdiff -> abdu

When I created KnownBits::absdiff, I totally missed that we already have ISD::ABDS/ABDU nodes, and we use this term in other places/targets as well.

I've added the KnownBits::abds implementation and renamed KnownBits::absdiff to KnownBits::abdu.

Followup to llvm#84791
@RKSimon RKSimon merged commit ffe4181 into llvm:main Mar 12, 2024
3 of 4 checks passed
@RKSimon RKSimon deleted the knownbits-abd branch March 12, 2024 12:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants