Skip to content

Commit

Permalink
Improve BitPackedIntSoA
Browse files Browse the repository at this point in the history
* Test single bit BitPackedIntSoA
* Add fastpath for single bit BitPackedIntSoA
* Catch accidental loss of magnitude for signed integers
* Change BitPackedIntSoA template parameter order. SignBit is used more often, so it should not be the last parameter.
  • Loading branch information
bernhardmgruber committed Jan 16, 2023
1 parent b482f89 commit dc3a3c6
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 17 deletions.
76 changes: 63 additions & 13 deletions include/llama/mapping/BitPackedIntSoA.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace llama::mapping
return bits >= sizeof(Integral) * CHAR_BIT ? ~Integral{0} : (Integral{1} << bits) - 1u;
}

template<typename Integral, bool KeepSignBit, typename StoredIntegral>
template<bool KeepSignBit, typename Integral, typename StoredIntegral>
LLAMA_FN_HOST_ACC_INLINE constexpr auto bitunpack(
const StoredIntegral* ptr,
StoredIntegral bitOffset,
Expand Down Expand Up @@ -68,7 +68,7 @@ namespace llama::mapping
return static_cast<Integral>(v);
}

template<typename Integral, bool KeepSignBit, typename StoredIntegral>
template<bool KeepSignBit, typename StoredIntegral, typename Integral>
LLAMA_FN_HOST_ACC_INLINE constexpr void bitpack(
StoredIntegral* ptr,
StoredIntegral bitOffset,
Expand Down Expand Up @@ -123,6 +123,27 @@ namespace llama::mapping
}
}

template<typename Integral, typename StoredIntegral>
LLAMA_FN_HOST_ACC_INLINE constexpr auto bitunpack1(const StoredIntegral* ptr, StoredIntegral bitOffset)
-> Integral
{
constexpr auto bitsPerStoredIntegral = static_cast<StoredIntegral>(sizeof(StoredIntegral) * CHAR_BIT);
const auto bit
= (ptr[bitOffset / bitsPerStoredIntegral] >> (bitOffset % bitsPerStoredIntegral)) & StoredIntegral{1};
return static_cast<Integral>(bit);
}

template<typename StoredIntegral, typename Integral>
LLAMA_FN_HOST_ACC_INLINE constexpr void bitpack1(StoredIntegral* ptr, StoredIntegral bitOffset, Integral value)
{
constexpr auto bitsPerStoredIntegral = static_cast<StoredIntegral>(sizeof(StoredIntegral) * CHAR_BIT);
const auto bitOff = bitOffset % bitsPerStoredIntegral;
auto& dst = ptr[bitOffset / bitsPerStoredIntegral];
dst &= ~(StoredIntegral{1} << bitOff); // clear bit
const auto bit = (static_cast<StoredIntegral>(value) & StoredIntegral{1});
dst |= (bit << bitOff); // set bit
}

/// A proxy type representing a reference to a reduced precision integral value, stored in a buffer at a
/// specified bit offset.
/// @tparam Integral Integral data type which can be loaded and store through this reference.
Expand Down Expand Up @@ -154,15 +175,33 @@ namespace llama::mapping
// NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions)
LLAMA_FN_HOST_ACC_INLINE constexpr operator Integral() const
{
return bitunpack<Integral, SignBit == SignBit::Keep>(
// fast path for single bits without sign handling
if constexpr(std::is_empty_v<VHBits>)
{
if constexpr(VHBits::value() == 1 && (std::is_unsigned_v<Integral> || SignBit == SignBit::Discard))
{
return bitunpack1<Integral>(ptr, static_cast<StoredIntegral>(bitOffset));
}
}

return bitunpack<SignBit == SignBit::Keep, Integral>(
ptr,
static_cast<StoredIntegral>(bitOffset),
static_cast<StoredIntegral>(VHBits::value()));
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator=(Integral value) -> BitPackedIntRef&
{
bitpack<Integral, SignBit == SignBit::Keep>(
// fast path for single bits without sign handling
if constexpr(std::is_empty_v<VHBits>)
{
if constexpr(VHBits::value() == 1 && (std::is_unsigned_v<Integral> || SignBit == SignBit::Discard))
{
bitpack1(ptr, static_cast<StoredIntegral>(bitOffset), value);
}
}

bitpack<SignBit == SignBit::Keep>(
ptr,
static_cast<StoredIntegral>(bitOffset),
static_cast<StoredIntegral>(VHBits::value()),
Expand All @@ -187,19 +226,19 @@ namespace llama::mapping
/// \tparam Bits If Bits is llama::Constant<N>, the compile-time N specifies the number of bits to use. If Bits is
/// an integral type T, the number of bits is specified at runtime, passed to the constructor and stored as type T.
/// Must not be zero and must not be bigger than the bits of TStoredIntegral.
/// @tparam SignBit When set to SignBit::Discard, discards the sign bit when storing signed integers. All
/// numbers will be read back positive.
/// \tparam TLinearizeArrayDimsFunctor Defines how the array dimensions should be mapped into linear numbers and
/// how big the linear domain gets.
/// \tparam TStoredIntegral Integral type used as storage of reduced precision integers. Must be std::uint32_t or
/// std::uint64_t.
/// @tparam SignBit When set to SignBit::Discard, discards the sign bit when storing signed integers. All
/// numbers will be read back positive.
template<
typename TArrayExtents,
typename TRecordDim,
typename Bits = typename TArrayExtents::value_type,
SignBit SignBit = SignBit::Keep,
typename TLinearizeArrayDimsFunctor = LinearizeArrayDimsCpp,
typename TStoredIntegral = internal::StoredUnsignedFor<TRecordDim>,
SignBit SignBit = SignBit::Keep>
typename TStoredIntegral = internal::StoredUnsignedFor<TRecordDim>>
struct BitPackedIntSoA
: MappingBase<TArrayExtents, TRecordDim>
, private llama::internal::BoxedValue<Bits>
Expand Down Expand Up @@ -256,6 +295,11 @@ namespace llama::mapping
static_assert(
static_cast<std::size_t>(VHBits::value()) <= sizeof(FieldType) * CHAR_BIT,
"Storage bits must not be greater than bits of field type");
static_assert(
VHBits::value() >= 2
|| std::is_unsigned_v<FieldType> || SignBit == llama::mapping::SignBit::Discard,
"When keeping the sign bit, Bits must be at least 2 with signed integers in the record "
"dimension");
});
}

Expand All @@ -280,6 +324,11 @@ namespace llama::mapping
if(static_cast<std::size_t>(VHBits::value()) > sizeof(FieldType) * CHAR_BIT)
throw std::invalid_argument(
"BitPackedIntSoA Bits must not be larger than any field type in the record dimension");
if(!(VHBits::value() >= 2
|| std::is_unsigned_v<FieldType> || SignBit == llama::mapping::SignBit::Discard))
throw std::invalid_argument(
"When keeping the sign bit, Bits must be at least 2 with signed integers in the record "
"dimension");
#endif
});
}
Expand Down Expand Up @@ -322,16 +371,17 @@ namespace llama::mapping
/// meta function accepting the latter two. Useful to to prepare this mapping for a meta mapping.
template<
typename Bits = void,
SignBit SignBit = SignBit::Keep,
typename LinearizeArrayDimsFunctor = mapping::LinearizeArrayDimsCpp,
typename StoredIntegral = void,
SignBit SignBit = SignBit::Keep>
typename StoredIntegral = void>
struct BindBitPackedIntSoA
{
template<typename ArrayExtents, typename RecordDim>
using fn = BitPackedIntSoA<
ArrayExtents,
RecordDim,
std::conditional_t<!std::is_void_v<Bits>, Bits, typename ArrayExtents::value_type>,
SignBit,
LinearizeArrayDimsFunctor,
std::conditional_t<
!std::is_void_v<StoredIntegral>,
Expand All @@ -346,9 +396,9 @@ namespace llama::mapping
typename ArrayExtents,
typename RecordDim,
typename Bits,
SignBit SignBit,
typename LinearizeArrayDimsFunctor,
typename StoredIntegral,
SignBit SignBit>
typename StoredIntegral>
inline constexpr bool isBitPackedIntSoA<
BitPackedIntSoA<ArrayExtents, RecordDim, Bits, LinearizeArrayDimsFunctor, StoredIntegral, SignBit>> = true;
BitPackedIntSoA<ArrayExtents, RecordDim, Bits, SignBit, LinearizeArrayDimsFunctor, StoredIntegral>> = true;
} // namespace llama::mapping
97 changes: 93 additions & 4 deletions tests/mapping.BitPackedIntSoA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ TEST_CASE("mapping.BitPackedIntSoA.ValidateBitsSmallerThanStorageIntegral")
llama::ArrayExtents<std::size_t, 16>,
std::uint32_t,
unsigned,
llama::mapping::SignBit::Keep,
llama::mapping::LinearizeArrayDimsCpp,
std::uint32_t>{{}, 40});
}
Expand All @@ -252,6 +253,11 @@ TEST_CASE("mapping.BitPackedIntSoA.ValidateBitsNotZero")
CHECK_THROWS(llama::mapping::BitPackedIntSoA<llama::ArrayExtents<std::size_t, 16>, UInts, unsigned>{{}, 0});
}

TEST_CASE("mapping.BitPackedIntSoA.ValidateBitsAtLeast2WithSignBit")
{
CHECK_THROWS(llama::mapping::BitPackedIntSoA<llama::ArrayExtents<std::size_t, 16>, SInts, unsigned>{{}, 1});
}

TEMPLATE_TEST_CASE(
"mapping.BitPackedIntSoA.bitpack",
"",
Expand All @@ -277,15 +283,15 @@ TEMPLATE_TEST_CASE(
for(StoredIntegral bitCount = 5; bitCount <= sizeof(Integral) * CHAR_BIT; bitCount++)
{
for(Integral i = 0; i < 32; i++)
llama::mapping::internal::bitpack<Integral, false>(
llama::mapping::internal::bitpack<false>(
blob.data(),
static_cast<StoredIntegral>(i * bitCount),
bitCount,
i);

for(Integral i = 0; i < 32; i++)
CHECK(
llama::mapping::internal::bitunpack<Integral, false>(
llama::mapping::internal::bitunpack<false, Integral>(
blob.data(),
static_cast<StoredIntegral>(i * bitCount),
bitCount)
Expand All @@ -298,15 +304,15 @@ TEMPLATE_TEST_CASE(
for(StoredIntegral bitCount = 5 + 1; bitCount <= sizeof(Integral) * CHAR_BIT; bitCount++)
{
for(Integral i = 0; i < 32; i++)
llama::mapping::internal::bitpack<Integral, true>(
llama::mapping::internal::bitpack<true>(
blob.data(),
static_cast<StoredIntegral>(i * bitCount),
bitCount,
i - 32);

for(Integral i = 0; i < 32; i++)
CHECK(
llama::mapping::internal::bitunpack<Integral, true>(
llama::mapping::internal::bitunpack<true, Integral>(
blob.data(),
static_cast<StoredIntegral>(i * bitCount),
bitCount)
Expand All @@ -315,4 +321,87 @@ TEMPLATE_TEST_CASE(
}
}
});
}

TEMPLATE_TEST_CASE(
"mapping.BitPackedIntSoA.bitpack.1bit",
"",
std::int8_t,
std::int16_t,
std::int32_t,
std::int64_t,
std::uint8_t,
std::uint16_t,
std::uint32_t,
std::uint64_t)
{
using Integral = TestType;
boost::mp11::mp_for_each<boost::mp11::mp_list<std::uint32_t, std::uint64_t>>(
[](auto si)
{
using StoredIntegral = decltype(si);
if constexpr(sizeof(StoredIntegral) >= sizeof(TestType))
{
constexpr auto bitsToWrite = 127;
std::vector<StoredIntegral> blob(
llama::divCeil(std::size_t{bitsToWrite}, sizeof(StoredIntegral) * CHAR_BIT));

for(Integral i = 0; i < bitsToWrite; i++)
llama::mapping::internal::bitpack<false>(
blob.data(),
static_cast<StoredIntegral>(i),
StoredIntegral{1},
static_cast<Integral>(i % 2));

for(Integral i = 0; i < bitsToWrite; i++)
CHECK(
llama::mapping::internal::bitunpack<false, Integral>(
blob.data(),
static_cast<StoredIntegral>(i),
StoredIntegral{1})
== static_cast<Integral>(i % 2));
}
});
}

TEMPLATE_TEST_CASE(
"mapping.BitPackedIntSoA.bitpack.1bit.fastpath",
"",
std::int8_t,
std::int16_t,
std::int32_t,
std::int64_t,
std::uint8_t,
std::uint16_t,
std::uint32_t,
std::uint64_t)
{
using Integral = TestType;
boost::mp11::mp_for_each<boost::mp11::mp_list<std::uint32_t, std::uint64_t>>(
[]<typename StoredIntegral>(StoredIntegral)
{
if constexpr(sizeof(StoredIntegral) >= sizeof(TestType))
{
constexpr auto bitsToWrite = 127;
std::vector<StoredIntegral> blob(
llama::divCeil(std::size_t{bitsToWrite}, sizeof(StoredIntegral) * CHAR_BIT));

for(Integral i = 0; i < bitsToWrite; i++)
llama::mapping::internal::bitpack1(
blob.data(),
static_cast<StoredIntegral>(i),
static_cast<Integral>(i % 2));

for(Integral i = 0; i < bitsToWrite; i++)
{
CAPTURE(i);
[[maybe_unused]] auto r
= llama::mapping::internal::bitunpack1<Integral>(blob.data(), static_cast<StoredIntegral>(i));
assert(r == static_cast<Integral>(i % 2));
CHECK(
llama::mapping::internal::bitunpack1<Integral>(blob.data(), static_cast<StoredIntegral>(i))
== static_cast<Integral>(i % 2));
}
}
});
}

0 comments on commit dc3a3c6

Please sign in to comment.