diff --git a/libc/src/__support/big_int.h b/libc/src/__support/big_int.h index 681782d57319e5e..246b89f08f2ff95 100644 --- a/libc/src/__support/big_int.h +++ b/libc/src/__support/big_int.h @@ -14,7 +14,7 @@ #include "src/__support/CPP/limits.h" #include "src/__support/CPP/optional.h" #include "src/__support/CPP/type_traits.h" -#include "src/__support/macros/attributes.h" // LIBC_INLINE +#include "src/__support/macros/attributes.h" // LIBC_INLINE #include "src/__support/macros/config.h" #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY #include "src/__support/macros/properties/compiler.h" // LIBC_COMPILER_IS_CLANG @@ -361,17 +361,94 @@ struct BigInt { LIBC_INLINE constexpr BigInt(const BigInt &other) = default; - template + template LIBC_INLINE constexpr BigInt( - const BigInt &other) { - if (OtherBits >= Bits) { // truncate - for (size_t i = 0; i < WORD_COUNT; ++i) - val[i] = other[i]; - } else { // zero or sign extend - size_t i = 0; - for (; i < OtherBits / WORD_SIZE; ++i) - val[i] = other[i]; - extend(i, Signed && other.is_neg()); + const BigInt &other) { + using BigIntOther = BigInt; + const bool should_sign_extend = Signed && other.is_neg(); + + static_assert(!(Bits == OtherBits && WORD_SIZE != BigIntOther::WORD_SIZE) && + "This is currently untested for casting between bigints with " + "the same bit width but different word sizes."); + + if constexpr (BigIntOther::WORD_SIZE < WORD_SIZE) { + // OtherWordType is smaller + constexpr size_t WORD_SIZE_RATIO = WORD_SIZE / BigIntOther::WORD_SIZE; + static_assert( + (WORD_SIZE % BigIntOther::WORD_SIZE) == 0 && + "Word types must be multiples of each other for correct conversion."); + if constexpr (OtherBits >= Bits) { // truncate + // for each big word + for (size_t i = 0; i < WORD_COUNT; ++i) { + WordType cur_word = 0; + // combine WORD_SIZE_RATIO small words into a big word + for (size_t j = 0; j < WORD_SIZE_RATIO; ++j) + cur_word |= static_cast(other[(i * WORD_SIZE_RATIO) + j]) + << (BigIntOther::WORD_SIZE * j); + + val[i] = cur_word; + } + } else { // zero or sign extend + size_t i = 0; + WordType cur_word = 0; + // for each small word + for (; i < BigIntOther::WORD_COUNT; ++i) { + // combine WORD_SIZE_RATIO small words into a big word + cur_word |= static_cast(other[i]) + << (BigIntOther::WORD_SIZE * (i % WORD_SIZE_RATIO)); + // if we've completed a big word, copy it into place and reset + if ((i % WORD_SIZE_RATIO) == WORD_SIZE_RATIO - 1) { + val[i / WORD_SIZE_RATIO] = cur_word; + cur_word = 0; + } + } + // Pretend there are extra words of the correct sign extension as needed + + const WordType extension_bits = + should_sign_extend ? cpp::numeric_limits::max() + : cpp::numeric_limits::min(); + if ((i % WORD_SIZE_RATIO) != 0) { + cur_word |= static_cast(extension_bits) + << (BigIntOther::WORD_SIZE * (i % WORD_SIZE_RATIO)); + } + // Copy the last word into place. + val[(i / WORD_SIZE_RATIO)] = cur_word; + extend((i / WORD_SIZE_RATIO) + 1, should_sign_extend); + } + } else if constexpr (BigIntOther::WORD_SIZE == WORD_SIZE) { + if constexpr (OtherBits >= Bits) { // truncate + for (size_t i = 0; i < WORD_COUNT; ++i) + val[i] = other[i]; + } else { // zero or sign extend + size_t i = 0; + for (; i < BigIntOther::WORD_COUNT; ++i) + val[i] = other[i]; + extend(i, should_sign_extend); + } + } else { + // OtherWordType is bigger. + constexpr size_t WORD_SIZE_RATIO = BigIntOther::WORD_SIZE / WORD_SIZE; + static_assert( + (BigIntOther::WORD_SIZE % WORD_SIZE) == 0 && + "Word types must be multiples of each other for correct conversion."); + if constexpr (OtherBits >= Bits) { // truncate + // for each small word + for (size_t i = 0; i < WORD_COUNT; ++i) { + // split each big word into WORD_SIZE_RATIO small words + val[i] = static_cast(other[i / WORD_SIZE_RATIO] >> + ((i % WORD_SIZE_RATIO) * WORD_SIZE)); + } + } else { // zero or sign extend + size_t i = 0; + // for each big word + for (; i < BigIntOther::WORD_COUNT; ++i) { + // split each big word into WORD_SIZE_RATIO small words + for (size_t j = 0; j < WORD_SIZE_RATIO; ++j) + val[(i * WORD_SIZE_RATIO) + j] = + static_cast(other[i] >> (j * WORD_SIZE)); + } + extend(i * WORD_SIZE_RATIO, should_sign_extend); + } } } diff --git a/libc/test/src/__support/big_int_test.cpp b/libc/test/src/__support/big_int_test.cpp index a1ce69baaae2906..471ca72a8f6e0c8 100644 --- a/libc/test/src/__support/big_int_test.cpp +++ b/libc/test/src/__support/big_int_test.cpp @@ -8,7 +8,7 @@ #include "src/__support/CPP/optional.h" #include "src/__support/big_int.h" -#include "src/__support/integer_literals.h" // parse_unsigned_bigint +#include "src/__support/integer_literals.h" // parse_unsigned_bigint #include "src/__support/macros/config.h" #include "src/__support/macros/properties/types.h" // LIBC_TYPES_HAS_INT128 @@ -208,6 +208,7 @@ TYPED_TEST(LlvmLibcUIntClassTest, CountBits, Types) { } using LL_UInt16 = UInt<16>; +using LL_UInt32 = UInt<32>; using LL_UInt64 = UInt<64>; // We want to test UInt<128> explicitly. So, for // convenience, we use a sugar which does not conflict with the UInt128 type @@ -927,4 +928,143 @@ TEST(LlvmLibcUIntClassTest, OtherWordTypeTests) { ASSERT_EQ(static_cast(a >> 64), 1); } +TEST(LlvmLibcUIntClassTest, OtherWordTypeCastTests) { + using LL_UInt96 = BigInt<96, false, uint32_t>; + + LL_UInt96 a({123, 456, 789}); + + ASSERT_EQ(static_cast(a), 123); + ASSERT_EQ(static_cast(a >> 32), 456); + ASSERT_EQ(static_cast(a >> 64), 789); + + // Bigger word with more bits to smaller word with less bits. + LL_UInt128 b(a); + + ASSERT_EQ(static_cast(b), 123); + ASSERT_EQ(static_cast(b >> 32), 456); + ASSERT_EQ(static_cast(b >> 64), 789); + ASSERT_EQ(static_cast(b >> 96), 0); + + b = (b << 32) + 987; + + ASSERT_EQ(static_cast(b), 987); + ASSERT_EQ(static_cast(b >> 32), 123); + ASSERT_EQ(static_cast(b >> 64), 456); + ASSERT_EQ(static_cast(b >> 96), 789); + + // Smaller word with less bits to bigger word with more bits. + LL_UInt96 c(b); + + ASSERT_EQ(static_cast(c), 987); + ASSERT_EQ(static_cast(c >> 32), 123); + ASSERT_EQ(static_cast(c >> 64), 456); + + // Smaller word with more bits to bigger word with less bits + LL_UInt64 d(c); + + ASSERT_EQ(static_cast(d), 987); + ASSERT_EQ(static_cast(d >> 32), 123); + + // Bigger word with less bits to smaller word with more bits + + LL_UInt96 e(d); + + ASSERT_EQ(static_cast(e), 987); + ASSERT_EQ(static_cast(e >> 32), 123); + + e = (e << 32) + 654; + + ASSERT_EQ(static_cast(e), 654); + ASSERT_EQ(static_cast(e >> 32), 987); + ASSERT_EQ(static_cast(e >> 64), 123); +} + +TEST(LlvmLibcUIntClassTest, SignedOtherWordTypeCastTests) { + using LL_Int64 = BigInt<64, true, uint64_t>; + using LL_Int96 = BigInt<96, true, uint32_t>; + + LL_Int64 zero_64(0); + LL_Int96 zero_96(0); + LL_Int192 zero_192(0); + + LL_Int96 plus_a({0x1234, 0x5678, 0x9ABC}); + + ASSERT_EQ(static_cast(plus_a), 0x1234); + ASSERT_EQ(static_cast(plus_a >> 32), 0x5678); + ASSERT_EQ(static_cast(plus_a >> 64), 0x9ABC); + + LL_Int96 minus_a(-plus_a); + + // The reason that the numbers are inverted and not negated is that we're + // using two's complement. To negate a two's complement number you flip the + // bits and add 1, so minus_a is {~0x1234, ~0x5678, ~0x9ABC} + {1,0,0}. + ASSERT_EQ(static_cast(minus_a), (~0x1234) + 1); + ASSERT_EQ(static_cast(minus_a >> 32), ~0x5678); + ASSERT_EQ(static_cast(minus_a >> 64), ~0x9ABC); + + ASSERT_TRUE(plus_a + minus_a == zero_96); + + // 192 so there's an extra block to get sign extended to + LL_Int192 bigger_plus_a(plus_a); + + ASSERT_EQ(static_cast(bigger_plus_a), 0x1234); + ASSERT_EQ(static_cast(bigger_plus_a >> 32), 0x5678); + ASSERT_EQ(static_cast(bigger_plus_a >> 64), 0x9ABC); + ASSERT_EQ(static_cast(bigger_plus_a >> 96), 0); + ASSERT_EQ(static_cast(bigger_plus_a >> 128), 0); + ASSERT_EQ(static_cast(bigger_plus_a >> 160), 0); + + LL_Int192 bigger_minus_a(minus_a); + + ASSERT_EQ(static_cast(bigger_minus_a), (~0x1234) + 1); + ASSERT_EQ(static_cast(bigger_minus_a >> 32), ~0x5678); + ASSERT_EQ(static_cast(bigger_minus_a >> 64), ~0x9ABC); + ASSERT_EQ(static_cast(bigger_minus_a >> 96), ~0); + ASSERT_EQ(static_cast(bigger_minus_a >> 128), ~0); + ASSERT_EQ(static_cast(bigger_minus_a >> 160), ~0); + + ASSERT_TRUE(bigger_plus_a + bigger_minus_a == zero_192); + + LL_Int64 smaller_plus_a(plus_a); + + ASSERT_EQ(static_cast(smaller_plus_a), 0x1234); + ASSERT_EQ(static_cast(smaller_plus_a >> 32), 0x5678); + + LL_Int64 smaller_minus_a(minus_a); + + ASSERT_EQ(static_cast(smaller_minus_a), (~0x1234) + 1); + ASSERT_EQ(static_cast(smaller_minus_a >> 32), ~0x5678); + + ASSERT_TRUE(smaller_plus_a + smaller_minus_a == zero_64); + + // Also try going from bigger word size to smaller word size + LL_Int96 smaller_back_plus_a(smaller_plus_a); + + ASSERT_EQ(static_cast(smaller_back_plus_a), 0x1234); + ASSERT_EQ(static_cast(smaller_back_plus_a >> 32), 0x5678); + ASSERT_EQ(static_cast(smaller_back_plus_a >> 64), 0); + + LL_Int96 smaller_back_minus_a(smaller_minus_a); + + ASSERT_EQ(static_cast(smaller_back_minus_a), (~0x1234) + 1); + ASSERT_EQ(static_cast(smaller_back_minus_a >> 32), ~0x5678); + ASSERT_EQ(static_cast(smaller_back_minus_a >> 64), ~0); + + ASSERT_TRUE(smaller_back_plus_a + smaller_back_minus_a == zero_96); + + LL_Int96 bigger_back_plus_a(bigger_plus_a); + + ASSERT_EQ(static_cast(bigger_back_plus_a), 0x1234); + ASSERT_EQ(static_cast(bigger_back_plus_a >> 32), 0x5678); + ASSERT_EQ(static_cast(bigger_back_plus_a >> 64), 0x9ABC); + + LL_Int96 bigger_back_minus_a(bigger_minus_a); + + ASSERT_EQ(static_cast(bigger_back_minus_a), (~0x1234) + 1); + ASSERT_EQ(static_cast(bigger_back_minus_a >> 32), ~0x5678); + ASSERT_EQ(static_cast(bigger_back_minus_a >> 64), ~0x9ABC); + + ASSERT_TRUE(bigger_back_plus_a + bigger_back_minus_a == zero_96); +} + } // namespace LIBC_NAMESPACE_DECL