From a3f7c8ecc5fdf7ee27cf96f2c3e365b9552ecc7c Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Mon, 11 Mar 2024 16:38:44 +0000 Subject: [PATCH] [ADT] Add APIntOps::abds signed absolute difference and rename absdiff -> abdu When I created APIntOps::absdiff, I totally missed that we already have ISD::ABDS/ABDU nodes, and we use this term in other places/targets as well. I've added the APIntOps::abds implementation and renamed APIntOps::absdiff to APIntOps::abdu. Given that APIntOps::absdiff is so young I don't think we need to create a deprecation wrapper, but I can if anyone thinks it important. I'll do a KnownBits rename patch after this. --- llvm/include/llvm/ADT/APInt.h | 7 +- .../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 4 +- llvm/unittests/ADT/APIntTest.cpp | 66 ++++++++++++++----- llvm/unittests/Support/KnownBitsTest.cpp | 2 +- 4 files changed, 59 insertions(+), 20 deletions(-) diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h index 1fc3c7b2236a17..bea3e28adf308f 100644 --- a/llvm/include/llvm/ADT/APInt.h +++ b/llvm/include/llvm/ADT/APInt.h @@ -2188,8 +2188,13 @@ inline const APInt &umax(const APInt &A, const APInt &B) { return A.ugt(B) ? A : B; } +/// Determine the absolute difference of two APInts considered to be signed. +inline const APInt abds(const APInt &A, const APInt &B) { + return A.sge(B) ? (A - B) : (B - A); +} + /// Determine the absolute difference of two APInts considered to be unsigned. -inline const APInt absdiff(const APInt &A, const APInt &B) { +inline const APInt abdu(const APInt &A, const APInt &B) { return A.uge(B) ? (A - B) : (B - A); } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 06fe716a22db0f..ccd4025fafe796 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6066,9 +6066,9 @@ static std::optional FoldValue(unsigned Opcode, const APInt &C1, return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1); } case ISD::ABDS: - return APIntOps::smax(C1, C2) - APIntOps::smin(C1, C2); + return APIntOps::abds(C1, C2); case ISD::ABDU: - return APIntOps::umax(C1, C2) - APIntOps::umin(C1, C2); + return APIntOps::abdu(C1, C2); } return std::nullopt; } diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp index 24324822356bf6..11237d2e1602dc 100644 --- a/llvm/unittests/ADT/APIntTest.cpp +++ b/llvm/unittests/ADT/APIntTest.cpp @@ -2532,38 +2532,72 @@ TEST(APIntTest, clearLowBits) { EXPECT_EQ(16u, i32hi16.popcount()); } -TEST(APIntTest, AbsDiff) { - using APIntOps::absdiff; +TEST(APIntTest, abds) { + using APIntOps::abds; APInt MaxU1(1, 1, false); APInt MinU1(1, 0, false); - EXPECT_EQ(1u, absdiff(MaxU1, MinU1).getZExtValue()); - EXPECT_EQ(1u, absdiff(MinU1, MaxU1).getZExtValue()); + EXPECT_EQ(1u, abds(MaxU1, MinU1).getZExtValue()); + EXPECT_EQ(1u, abds(MinU1, MaxU1).getZExtValue()); APInt MaxU4(4, 15, false); APInt MinU4(4, 0, false); - EXPECT_EQ(15u, absdiff(MaxU4, MinU4).getZExtValue()); - EXPECT_EQ(15u, absdiff(MinU4, MaxU4).getZExtValue()); + EXPECT_EQ(1, abds(MaxU4, MinU4).getSExtValue()); + EXPECT_EQ(1, abds(MinU4, MaxU4).getSExtValue()); APInt MaxS8(8, 127, true); APInt MinS8(8, -128, true); - EXPECT_EQ(1u, absdiff(MaxS8, MinS8).getZExtValue()); - EXPECT_EQ(1u, absdiff(MinS8, MaxS8).getZExtValue()); + EXPECT_EQ(-1, abds(MaxS8, MinS8).getSExtValue()); + EXPECT_EQ(-1, abds(MinS8, MaxS8).getSExtValue()); APInt MaxU16(16, 65535, false); APInt MinU16(16, 0, false); - EXPECT_EQ(65535u, absdiff(MaxU16, MinU16).getZExtValue()); - EXPECT_EQ(65535u, absdiff(MinU16, MaxU16).getZExtValue()); + EXPECT_EQ(1, abds(MaxU16, MinU16).getSExtValue()); + EXPECT_EQ(1, abds(MinU16, MaxU16).getSExtValue()); APInt MaxS16(16, 32767, true); APInt MinS16(16, -32768, true); APInt ZeroS16(16, 0, true); - EXPECT_EQ(1u, absdiff(MaxS16, MinS16).getZExtValue()); - EXPECT_EQ(1u, absdiff(MinS16, MaxS16).getZExtValue()); - EXPECT_EQ(32768u, absdiff(ZeroS16, MinS16)); - EXPECT_EQ(32768u, absdiff(MinS16, ZeroS16)); - EXPECT_EQ(32767u, absdiff(ZeroS16, MaxS16)); - EXPECT_EQ(32767u, absdiff(MaxS16, ZeroS16)); + EXPECT_EQ(-1, abds(MaxS16, MinS16).getSExtValue()); + EXPECT_EQ(-1, abds(MinS16, MaxS16).getSExtValue()); + EXPECT_EQ(32768u, abds(ZeroS16, MinS16)); + EXPECT_EQ(32768u, abds(MinS16, ZeroS16)); + EXPECT_EQ(32767u, abds(ZeroS16, MaxS16)); + EXPECT_EQ(32767u, abds(MaxS16, ZeroS16)); +} + +TEST(APIntTest, abdu) { + using APIntOps::abdu; + + APInt MaxU1(1, 1, false); + APInt MinU1(1, 0, false); + EXPECT_EQ(1u, abdu(MaxU1, MinU1).getZExtValue()); + EXPECT_EQ(1u, abdu(MinU1, MaxU1).getZExtValue()); + + APInt MaxU4(4, 15, false); + APInt MinU4(4, 0, false); + EXPECT_EQ(15u, abdu(MaxU4, MinU4).getZExtValue()); + EXPECT_EQ(15u, abdu(MinU4, MaxU4).getZExtValue()); + + APInt MaxS8(8, 127, true); + APInt MinS8(8, -128, true); + EXPECT_EQ(1u, abdu(MaxS8, MinS8).getZExtValue()); + EXPECT_EQ(1u, abdu(MinS8, MaxS8).getZExtValue()); + + APInt MaxU16(16, 65535, false); + APInt MinU16(16, 0, false); + EXPECT_EQ(65535u, abdu(MaxU16, MinU16).getZExtValue()); + EXPECT_EQ(65535u, abdu(MinU16, MaxU16).getZExtValue()); + + APInt MaxS16(16, 32767, true); + APInt MinS16(16, -32768, true); + APInt ZeroS16(16, 0, true); + EXPECT_EQ(1u, abdu(MaxS16, MinS16).getZExtValue()); + EXPECT_EQ(1u, abdu(MinS16, MaxS16).getZExtValue()); + EXPECT_EQ(32768u, abdu(ZeroS16, MinS16)); + EXPECT_EQ(32768u, abdu(MinS16, ZeroS16)); + EXPECT_EQ(32767u, abdu(ZeroS16, MaxS16)); + EXPECT_EQ(32767u, abdu(MaxS16, ZeroS16)); } TEST(APIntTest, GCD) { diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp index 658f3796721c4e..701876c7e418cd 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -362,7 +362,7 @@ TEST(KnownBitsTest, BinaryExhaustive) { return KnownBits::absdiff(Known1, Known2); }, [](const APInt &N1, const APInt &N2) { - return APIntOps::absdiff(N1, N2); + return APIntOps::abdu(N1, N2); }, checkCorrectnessOnlyBinary); testBinaryOpExhaustive(