Skip to content

Commit

Permalink
[libc] Add bigint casting between word types (llvm#111914)
Browse files Browse the repository at this point in the history
Previously you could cast between bigints with different numbers of
bits, but only if they had the same underlying type. This patch adds the
ability to cast between bigints with different underlying types, which
is needed for llvm#110894
  • Loading branch information
michaelrj-google authored and EricWF committed Oct 22, 2024
1 parent 78f0cf8 commit 7014cf4
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 12 deletions.
99 changes: 88 additions & 11 deletions libc/src/__support/big_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "src/__support/CPP/limits.h"
#include "src/__support/CPP/optional.h"
#include "src/__support/CPP/type_traits.h"
#include "src/__support/macros/attributes.h" // LIBC_INLINE
#include "src/__support/macros/attributes.h" // LIBC_INLINE
#include "src/__support/macros/config.h"
#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
#include "src/__support/macros/properties/compiler.h" // LIBC_COMPILER_IS_CLANG
Expand Down Expand Up @@ -361,17 +361,94 @@ struct BigInt {

LIBC_INLINE constexpr BigInt(const BigInt &other) = default;

template <size_t OtherBits, bool OtherSigned>
template <size_t OtherBits, bool OtherSigned, typename OtherWordType>
LIBC_INLINE constexpr BigInt(
const BigInt<OtherBits, OtherSigned, WordType> &other) {
if (OtherBits >= Bits) { // truncate
for (size_t i = 0; i < WORD_COUNT; ++i)
val[i] = other[i];
} else { // zero or sign extend
size_t i = 0;
for (; i < OtherBits / WORD_SIZE; ++i)
val[i] = other[i];
extend(i, Signed && other.is_neg());
const BigInt<OtherBits, OtherSigned, OtherWordType> &other) {
using BigIntOther = BigInt<OtherBits, OtherSigned, OtherWordType>;
const bool should_sign_extend = Signed && other.is_neg();

static_assert(!(Bits == OtherBits && WORD_SIZE != BigIntOther::WORD_SIZE) &&
"This is currently untested for casting between bigints with "
"the same bit width but different word sizes.");

if constexpr (BigIntOther::WORD_SIZE < WORD_SIZE) {
// OtherWordType is smaller
constexpr size_t WORD_SIZE_RATIO = WORD_SIZE / BigIntOther::WORD_SIZE;
static_assert(
(WORD_SIZE % BigIntOther::WORD_SIZE) == 0 &&
"Word types must be multiples of each other for correct conversion.");
if constexpr (OtherBits >= Bits) { // truncate
// for each big word
for (size_t i = 0; i < WORD_COUNT; ++i) {
WordType cur_word = 0;
// combine WORD_SIZE_RATIO small words into a big word
for (size_t j = 0; j < WORD_SIZE_RATIO; ++j)
cur_word |= static_cast<WordType>(other[(i * WORD_SIZE_RATIO) + j])
<< (BigIntOther::WORD_SIZE * j);

val[i] = cur_word;
}
} else { // zero or sign extend
size_t i = 0;
WordType cur_word = 0;
// for each small word
for (; i < BigIntOther::WORD_COUNT; ++i) {
// combine WORD_SIZE_RATIO small words into a big word
cur_word |= static_cast<WordType>(other[i])
<< (BigIntOther::WORD_SIZE * (i % WORD_SIZE_RATIO));
// if we've completed a big word, copy it into place and reset
if ((i % WORD_SIZE_RATIO) == WORD_SIZE_RATIO - 1) {
val[i / WORD_SIZE_RATIO] = cur_word;
cur_word = 0;
}
}
// Pretend there are extra words of the correct sign extension as needed

const WordType extension_bits =
should_sign_extend ? cpp::numeric_limits<WordType>::max()
: cpp::numeric_limits<WordType>::min();
if ((i % WORD_SIZE_RATIO) != 0) {
cur_word |= static_cast<WordType>(extension_bits)
<< (BigIntOther::WORD_SIZE * (i % WORD_SIZE_RATIO));
}
// Copy the last word into place.
val[(i / WORD_SIZE_RATIO)] = cur_word;
extend((i / WORD_SIZE_RATIO) + 1, should_sign_extend);
}
} else if constexpr (BigIntOther::WORD_SIZE == WORD_SIZE) {
if constexpr (OtherBits >= Bits) { // truncate
for (size_t i = 0; i < WORD_COUNT; ++i)
val[i] = other[i];
} else { // zero or sign extend
size_t i = 0;
for (; i < BigIntOther::WORD_COUNT; ++i)
val[i] = other[i];
extend(i, should_sign_extend);
}
} else {
// OtherWordType is bigger.
constexpr size_t WORD_SIZE_RATIO = BigIntOther::WORD_SIZE / WORD_SIZE;
static_assert(
(BigIntOther::WORD_SIZE % WORD_SIZE) == 0 &&
"Word types must be multiples of each other for correct conversion.");
if constexpr (OtherBits >= Bits) { // truncate
// for each small word
for (size_t i = 0; i < WORD_COUNT; ++i) {
// split each big word into WORD_SIZE_RATIO small words
val[i] = static_cast<WordType>(other[i / WORD_SIZE_RATIO] >>
((i % WORD_SIZE_RATIO) * WORD_SIZE));
}
} else { // zero or sign extend
size_t i = 0;
// for each big word
for (; i < BigIntOther::WORD_COUNT; ++i) {
// split each big word into WORD_SIZE_RATIO small words
for (size_t j = 0; j < WORD_SIZE_RATIO; ++j)
val[(i * WORD_SIZE_RATIO) + j] =
static_cast<WordType>(other[i] >> (j * WORD_SIZE));
}
extend(i * WORD_SIZE_RATIO, should_sign_extend);
}
}
}

Expand Down
142 changes: 141 additions & 1 deletion libc/test/src/__support/big_int_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#include "src/__support/CPP/optional.h"
#include "src/__support/big_int.h"
#include "src/__support/integer_literals.h" // parse_unsigned_bigint
#include "src/__support/integer_literals.h" // parse_unsigned_bigint
#include "src/__support/macros/config.h"
#include "src/__support/macros/properties/types.h" // LIBC_TYPES_HAS_INT128

Expand Down Expand Up @@ -208,6 +208,7 @@ TYPED_TEST(LlvmLibcUIntClassTest, CountBits, Types) {
}

using LL_UInt16 = UInt<16>;
using LL_UInt32 = UInt<32>;
using LL_UInt64 = UInt<64>;
// We want to test UInt<128> explicitly. So, for
// convenience, we use a sugar which does not conflict with the UInt128 type
Expand Down Expand Up @@ -927,4 +928,143 @@ TEST(LlvmLibcUIntClassTest, OtherWordTypeTests) {
ASSERT_EQ(static_cast<int>(a >> 64), 1);
}

TEST(LlvmLibcUIntClassTest, OtherWordTypeCastTests) {
using LL_UInt96 = BigInt<96, false, uint32_t>;

LL_UInt96 a({123, 456, 789});

ASSERT_EQ(static_cast<int>(a), 123);
ASSERT_EQ(static_cast<int>(a >> 32), 456);
ASSERT_EQ(static_cast<int>(a >> 64), 789);

// Bigger word with more bits to smaller word with less bits.
LL_UInt128 b(a);

ASSERT_EQ(static_cast<int>(b), 123);
ASSERT_EQ(static_cast<int>(b >> 32), 456);
ASSERT_EQ(static_cast<int>(b >> 64), 789);
ASSERT_EQ(static_cast<int>(b >> 96), 0);

b = (b << 32) + 987;

ASSERT_EQ(static_cast<int>(b), 987);
ASSERT_EQ(static_cast<int>(b >> 32), 123);
ASSERT_EQ(static_cast<int>(b >> 64), 456);
ASSERT_EQ(static_cast<int>(b >> 96), 789);

// Smaller word with less bits to bigger word with more bits.
LL_UInt96 c(b);

ASSERT_EQ(static_cast<int>(c), 987);
ASSERT_EQ(static_cast<int>(c >> 32), 123);
ASSERT_EQ(static_cast<int>(c >> 64), 456);

// Smaller word with more bits to bigger word with less bits
LL_UInt64 d(c);

ASSERT_EQ(static_cast<int>(d), 987);
ASSERT_EQ(static_cast<int>(d >> 32), 123);

// Bigger word with less bits to smaller word with more bits

LL_UInt96 e(d);

ASSERT_EQ(static_cast<int>(e), 987);
ASSERT_EQ(static_cast<int>(e >> 32), 123);

e = (e << 32) + 654;

ASSERT_EQ(static_cast<int>(e), 654);
ASSERT_EQ(static_cast<int>(e >> 32), 987);
ASSERT_EQ(static_cast<int>(e >> 64), 123);
}

TEST(LlvmLibcUIntClassTest, SignedOtherWordTypeCastTests) {
using LL_Int64 = BigInt<64, true, uint64_t>;
using LL_Int96 = BigInt<96, true, uint32_t>;

LL_Int64 zero_64(0);
LL_Int96 zero_96(0);
LL_Int192 zero_192(0);

LL_Int96 plus_a({0x1234, 0x5678, 0x9ABC});

ASSERT_EQ(static_cast<int>(plus_a), 0x1234);
ASSERT_EQ(static_cast<int>(plus_a >> 32), 0x5678);
ASSERT_EQ(static_cast<int>(plus_a >> 64), 0x9ABC);

LL_Int96 minus_a(-plus_a);

// The reason that the numbers are inverted and not negated is that we're
// using two's complement. To negate a two's complement number you flip the
// bits and add 1, so minus_a is {~0x1234, ~0x5678, ~0x9ABC} + {1,0,0}.
ASSERT_EQ(static_cast<int>(minus_a), (~0x1234) + 1);
ASSERT_EQ(static_cast<int>(minus_a >> 32), ~0x5678);
ASSERT_EQ(static_cast<int>(minus_a >> 64), ~0x9ABC);

ASSERT_TRUE(plus_a + minus_a == zero_96);

// 192 so there's an extra block to get sign extended to
LL_Int192 bigger_plus_a(plus_a);

ASSERT_EQ(static_cast<int>(bigger_plus_a), 0x1234);
ASSERT_EQ(static_cast<int>(bigger_plus_a >> 32), 0x5678);
ASSERT_EQ(static_cast<int>(bigger_plus_a >> 64), 0x9ABC);
ASSERT_EQ(static_cast<int>(bigger_plus_a >> 96), 0);
ASSERT_EQ(static_cast<int>(bigger_plus_a >> 128), 0);
ASSERT_EQ(static_cast<int>(bigger_plus_a >> 160), 0);

LL_Int192 bigger_minus_a(minus_a);

ASSERT_EQ(static_cast<int>(bigger_minus_a), (~0x1234) + 1);
ASSERT_EQ(static_cast<int>(bigger_minus_a >> 32), ~0x5678);
ASSERT_EQ(static_cast<int>(bigger_minus_a >> 64), ~0x9ABC);
ASSERT_EQ(static_cast<int>(bigger_minus_a >> 96), ~0);
ASSERT_EQ(static_cast<int>(bigger_minus_a >> 128), ~0);
ASSERT_EQ(static_cast<int>(bigger_minus_a >> 160), ~0);

ASSERT_TRUE(bigger_plus_a + bigger_minus_a == zero_192);

LL_Int64 smaller_plus_a(plus_a);

ASSERT_EQ(static_cast<int>(smaller_plus_a), 0x1234);
ASSERT_EQ(static_cast<int>(smaller_plus_a >> 32), 0x5678);

LL_Int64 smaller_minus_a(minus_a);

ASSERT_EQ(static_cast<int>(smaller_minus_a), (~0x1234) + 1);
ASSERT_EQ(static_cast<int>(smaller_minus_a >> 32), ~0x5678);

ASSERT_TRUE(smaller_plus_a + smaller_minus_a == zero_64);

// Also try going from bigger word size to smaller word size
LL_Int96 smaller_back_plus_a(smaller_plus_a);

ASSERT_EQ(static_cast<int>(smaller_back_plus_a), 0x1234);
ASSERT_EQ(static_cast<int>(smaller_back_plus_a >> 32), 0x5678);
ASSERT_EQ(static_cast<int>(smaller_back_plus_a >> 64), 0);

LL_Int96 smaller_back_minus_a(smaller_minus_a);

ASSERT_EQ(static_cast<int>(smaller_back_minus_a), (~0x1234) + 1);
ASSERT_EQ(static_cast<int>(smaller_back_minus_a >> 32), ~0x5678);
ASSERT_EQ(static_cast<int>(smaller_back_minus_a >> 64), ~0);

ASSERT_TRUE(smaller_back_plus_a + smaller_back_minus_a == zero_96);

LL_Int96 bigger_back_plus_a(bigger_plus_a);

ASSERT_EQ(static_cast<int>(bigger_back_plus_a), 0x1234);
ASSERT_EQ(static_cast<int>(bigger_back_plus_a >> 32), 0x5678);
ASSERT_EQ(static_cast<int>(bigger_back_plus_a >> 64), 0x9ABC);

LL_Int96 bigger_back_minus_a(bigger_minus_a);

ASSERT_EQ(static_cast<int>(bigger_back_minus_a), (~0x1234) + 1);
ASSERT_EQ(static_cast<int>(bigger_back_minus_a >> 32), ~0x5678);
ASSERT_EQ(static_cast<int>(bigger_back_minus_a >> 64), ~0x9ABC);

ASSERT_TRUE(bigger_back_plus_a + bigger_back_minus_a == zero_96);
}

} // namespace LIBC_NAMESPACE_DECL

0 comments on commit 7014cf4

Please sign in to comment.