Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various BitPackedIntSoA improvements #677

Merged
merged 1 commit into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 63 additions & 13 deletions include/llama/mapping/BitPackedIntSoA.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace llama::mapping
return bits >= sizeof(Integral) * CHAR_BIT ? ~Integral{0} : (Integral{1} << bits) - 1u;
}

template<typename Integral, bool KeepSignBit, typename StoredIntegral>
template<bool KeepSignBit, typename Integral, typename StoredIntegral>
LLAMA_FN_HOST_ACC_INLINE constexpr auto bitunpack(
const StoredIntegral* ptr,
StoredIntegral bitOffset,
Expand Down Expand Up @@ -68,7 +68,7 @@ namespace llama::mapping
return static_cast<Integral>(v);
}

template<typename Integral, bool KeepSignBit, typename StoredIntegral>
template<bool KeepSignBit, typename StoredIntegral, typename Integral>
LLAMA_FN_HOST_ACC_INLINE constexpr void bitpack(
StoredIntegral* ptr,
StoredIntegral bitOffset,
Expand Down Expand Up @@ -123,6 +123,27 @@ namespace llama::mapping
}
}

/// Reads the single bit stored at \p bitOffset (counted in bits from the start of \p ptr) and returns it
/// converted to \p Integral, i.e. either 0 or 1. No sign handling is performed.
/// \tparam Integral Integral type of the returned value.
/// \tparam StoredIntegral Integral type of the storage words the bit is read from.
template<typename Integral, typename StoredIntegral>
LLAMA_FN_HOST_ACC_INLINE constexpr auto bitunpack1(const StoredIntegral* ptr, StoredIntegral bitOffset)
    -> Integral
{
    constexpr auto wordBits = static_cast<StoredIntegral>(sizeof(StoredIntegral) * CHAR_BIT);
    const auto wordIndex = bitOffset / wordBits; // storage word holding the requested bit
    const auto shift = bitOffset % wordBits; // position of the bit inside that word
    const auto isolated = (ptr[wordIndex] >> shift) & StoredIntegral{1};
    return static_cast<Integral>(isolated);
}

/// Stores the least significant bit of \p value at \p bitOffset (counted in bits from the start of \p ptr).
/// All other bits of the affected storage word keep their previous value.
/// \tparam StoredIntegral Integral type of the storage words the bit is written to.
/// \tparam Integral Integral type of the value whose lowest bit is stored.
template<typename StoredIntegral, typename Integral>
LLAMA_FN_HOST_ACC_INLINE constexpr void bitpack1(StoredIntegral* ptr, StoredIntegral bitOffset, Integral value)
{
    constexpr auto wordBits = static_cast<StoredIntegral>(sizeof(StoredIntegral) * CHAR_BIT);
    const auto shift = bitOffset % wordBits; // position of the bit inside its storage word
    const auto lowBit = static_cast<StoredIntegral>(value) & StoredIntegral{1};
    auto& word = ptr[bitOffset / wordBits];
    word &= ~(StoredIntegral{1} << shift); // erase whatever was stored at the target position
    word |= (lowBit << shift); // write the new bit
}

/// A proxy type representing a reference to a reduced precision integral value, stored in a buffer at a
/// specified bit offset.
/// @tparam Integral Integral data type which can be loaded and stored through this reference.
Expand Down Expand Up @@ -154,15 +175,33 @@ namespace llama::mapping
// NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions)
LLAMA_FN_HOST_ACC_INLINE constexpr operator Integral() const
{
return bitunpack<Integral, SignBit == SignBit::Keep>(
// fast path for single bits without sign handling
if constexpr(std::is_empty_v<VHBits>)
{
if constexpr(VHBits::value() == 1 && (std::is_unsigned_v<Integral> || SignBit == SignBit::Discard))
{
return bitunpack1<Integral>(ptr, static_cast<StoredIntegral>(bitOffset));
}
}

return bitunpack<SignBit == SignBit::Keep, Integral>(
ptr,
static_cast<StoredIntegral>(bitOffset),
static_cast<StoredIntegral>(VHBits::value()));
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator=(Integral value) -> BitPackedIntRef&
{
bitpack<Integral, SignBit == SignBit::Keep>(
// fast path for single bits without sign handling
if constexpr(std::is_empty_v<VHBits>)
{
if constexpr(VHBits::value() == 1 && (std::is_unsigned_v<Integral> || SignBit == SignBit::Discard))
{
bitpack1(ptr, static_cast<StoredIntegral>(bitOffset), value);
}
}

bitpack<SignBit == SignBit::Keep>(
ptr,
static_cast<StoredIntegral>(bitOffset),
static_cast<StoredIntegral>(VHBits::value()),
Expand All @@ -187,19 +226,19 @@ namespace llama::mapping
/// \tparam Bits If Bits is llama::Constant<N>, the compile-time N specifies the number of bits to use. If Bits is
/// an integral type T, the number of bits is specified at runtime, passed to the constructor and stored as type T.
/// Must not be zero and must not be bigger than the bits of TStoredIntegral.
/// @tparam SignBit When set to SignBit::Discard, discards the sign bit when storing signed integers. All
/// numbers will be read back positive.
/// \tparam TLinearizeArrayDimsFunctor Defines how the array dimensions should be mapped into linear numbers and
/// how big the linear domain gets.
/// \tparam TStoredIntegral Integral type used as storage of reduced precision integers. Must be std::uint32_t or
/// std::uint64_t.
/// @tparam SignBit When set to SignBit::Discard, discards the sign bit when storing signed integers. All
/// numbers will be read back positive.
template<
typename TArrayExtents,
typename TRecordDim,
typename Bits = typename TArrayExtents::value_type,
SignBit SignBit = SignBit::Keep,
typename TLinearizeArrayDimsFunctor = LinearizeArrayDimsCpp,
typename TStoredIntegral = internal::StoredUnsignedFor<TRecordDim>,
SignBit SignBit = SignBit::Keep>
typename TStoredIntegral = internal::StoredUnsignedFor<TRecordDim>>
struct BitPackedIntSoA
: MappingBase<TArrayExtents, TRecordDim>
, private llama::internal::BoxedValue<Bits>
Expand Down Expand Up @@ -256,6 +295,11 @@ namespace llama::mapping
static_assert(
static_cast<std::size_t>(VHBits::value()) <= sizeof(FieldType) * CHAR_BIT,
"Storage bits must not be greater than bits of field type");
static_assert(
VHBits::value() >= 2
|| std::is_unsigned_v<FieldType> || SignBit == llama::mapping::SignBit::Discard,
"When keeping the sign bit, Bits must be at least 2 with signed integers in the record "
"dimension");
});
}

Expand All @@ -280,6 +324,11 @@ namespace llama::mapping
if(static_cast<std::size_t>(VHBits::value()) > sizeof(FieldType) * CHAR_BIT)
throw std::invalid_argument(
"BitPackedIntSoA Bits must not be larger than any field type in the record dimension");
if(!(VHBits::value() >= 2
|| std::is_unsigned_v<FieldType> || SignBit == llama::mapping::SignBit::Discard))
throw std::invalid_argument(
"When keeping the sign bit, Bits must be at least 2 with signed integers in the record "
"dimension");
#endif
});
}
Expand Down Expand Up @@ -322,16 +371,17 @@ namespace llama::mapping
/// meta function accepting the latter two. Useful to prepare this mapping for a meta mapping.
template<
typename Bits = void,
SignBit SignBit = SignBit::Keep,
typename LinearizeArrayDimsFunctor = mapping::LinearizeArrayDimsCpp,
typename StoredIntegral = void,
SignBit SignBit = SignBit::Keep>
typename StoredIntegral = void>
struct BindBitPackedIntSoA
{
template<typename ArrayExtents, typename RecordDim>
using fn = BitPackedIntSoA<
ArrayExtents,
RecordDim,
std::conditional_t<!std::is_void_v<Bits>, Bits, typename ArrayExtents::value_type>,
SignBit,
LinearizeArrayDimsFunctor,
std::conditional_t<
!std::is_void_v<StoredIntegral>,
Expand All @@ -346,9 +396,9 @@ namespace llama::mapping
typename ArrayExtents,
typename RecordDim,
typename Bits,
SignBit SignBit,
typename LinearizeArrayDimsFunctor,
typename StoredIntegral,
SignBit SignBit>
typename StoredIntegral>
inline constexpr bool isBitPackedIntSoA<
BitPackedIntSoA<ArrayExtents, RecordDim, Bits, LinearizeArrayDimsFunctor, StoredIntegral, SignBit>> = true;
BitPackedIntSoA<ArrayExtents, RecordDim, Bits, SignBit, LinearizeArrayDimsFunctor, StoredIntegral>> = true;
} // namespace llama::mapping
97 changes: 93 additions & 4 deletions tests/mapping.BitPackedIntSoA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ TEST_CASE("mapping.BitPackedIntSoA.ValidateBitsSmallerThanStorageIntegral")
llama::ArrayExtents<std::size_t, 16>,
std::uint32_t,
unsigned,
llama::mapping::SignBit::Keep,
llama::mapping::LinearizeArrayDimsCpp,
std::uint32_t>{{}, 40});
}
Expand All @@ -252,6 +253,11 @@ TEST_CASE("mapping.BitPackedIntSoA.ValidateBitsNotZero")
CHECK_THROWS(llama::mapping::BitPackedIntSoA<llama::ArrayExtents<std::size_t, 16>, UInts, unsigned>{{}, 0});
}

// Constructing the mapping with a runtime bit count of 1 is expected to throw when the sign bit is kept (the
// default). NOTE(review): SInts is declared outside this view; presumably a record dimension of signed
// integers - confirm.
TEST_CASE("mapping.BitPackedIntSoA.ValidateBitsAtLeast2WithSignBit")
{
    using Extents = llama::ArrayExtents<std::size_t, 16>;
    using Mapping = llama::mapping::BitPackedIntSoA<Extents, SInts, unsigned>;
    CHECK_THROWS(Mapping{{}, 1});
}

TEMPLATE_TEST_CASE(
"mapping.BitPackedIntSoA.bitpack",
"",
Expand All @@ -277,15 +283,15 @@ TEMPLATE_TEST_CASE(
for(StoredIntegral bitCount = 5; bitCount <= sizeof(Integral) * CHAR_BIT; bitCount++)
{
for(Integral i = 0; i < 32; i++)
llama::mapping::internal::bitpack<Integral, false>(
llama::mapping::internal::bitpack<false>(
blob.data(),
static_cast<StoredIntegral>(i * bitCount),
bitCount,
i);

for(Integral i = 0; i < 32; i++)
CHECK(
llama::mapping::internal::bitunpack<Integral, false>(
llama::mapping::internal::bitunpack<false, Integral>(
blob.data(),
static_cast<StoredIntegral>(i * bitCount),
bitCount)
Expand All @@ -298,15 +304,15 @@ TEMPLATE_TEST_CASE(
for(StoredIntegral bitCount = 5 + 1; bitCount <= sizeof(Integral) * CHAR_BIT; bitCount++)
{
for(Integral i = 0; i < 32; i++)
llama::mapping::internal::bitpack<Integral, true>(
llama::mapping::internal::bitpack<true>(
blob.data(),
static_cast<StoredIntegral>(i * bitCount),
bitCount,
i - 32);

for(Integral i = 0; i < 32; i++)
CHECK(
llama::mapping::internal::bitunpack<Integral, true>(
llama::mapping::internal::bitunpack<true, Integral>(
blob.data(),
static_cast<StoredIntegral>(i * bitCount),
bitCount)
Expand All @@ -315,4 +321,87 @@ TEMPLATE_TEST_CASE(
}
}
});
}

// Round-trips an alternating 0/1 pattern through the generic bitpack/bitunpack routines using a bit count of 1
// and without keeping the sign bit, for every tested integral type and every storage type at least as large.
TEMPLATE_TEST_CASE(
    "mapping.BitPackedIntSoA.bitpack.1bit",
    "",
    std::int8_t,
    std::int16_t,
    std::int32_t,
    std::int64_t,
    std::uint8_t,
    std::uint16_t,
    std::uint32_t,
    std::uint64_t)
{
    using Integral = TestType;
    using StorageTypes = boost::mp11::mp_list<std::uint32_t, std::uint64_t>;
    boost::mp11::mp_for_each<StorageTypes>(
        [](auto storageTag)
        {
            using StoredIntegral = decltype(storageTag);
            if constexpr(sizeof(StoredIntegral) >= sizeof(TestType))
            {
                // 127 is still representable in std::int8_t, so the Integral loop counters below do not overflow.
                constexpr auto bitsToWrite = 127;
                constexpr auto bitsPerWord = sizeof(StoredIntegral) * CHAR_BIT;
                std::vector<StoredIntegral> storage(llama::divCeil(std::size_t{bitsToWrite}, bitsPerWord));

                // write phase: bit number `pos` receives the value pos % 2
                for(Integral pos = 0; pos < bitsToWrite; pos++)
                {
                    const auto offset = static_cast<StoredIntegral>(pos);
                    const auto bitValue = static_cast<Integral>(pos % 2);
                    llama::mapping::internal::bitpack<false>(storage.data(), offset, StoredIntegral{1}, bitValue);
                }

                // read phase: every bit must come back unchanged
                for(Integral pos = 0; pos < bitsToWrite; pos++)
                {
                    const auto offset = static_cast<StoredIntegral>(pos);
                    CHECK(
                        llama::mapping::internal::bitunpack<false, Integral>(
                            storage.data(),
                            offset,
                            StoredIntegral{1})
                        == static_cast<Integral>(pos % 2));
                }
            }
        });
}

// Round-trips an alternating 0/1 pattern through the dedicated single-bit routines bitpack1/bitunpack1, for
// every tested integral type and every storage type at least as large.
TEMPLATE_TEST_CASE(
    "mapping.BitPackedIntSoA.bitpack.1bit.fastpath",
    "",
    std::int8_t,
    std::int16_t,
    std::int32_t,
    std::int64_t,
    std::uint8_t,
    std::uint16_t,
    std::uint32_t,
    std::uint64_t)
{
    using Integral = TestType;
    boost::mp11::mp_for_each<boost::mp11::mp_list<std::uint32_t, std::uint64_t>>(
        []<typename StoredIntegral>(StoredIntegral)
        {
            if constexpr(sizeof(StoredIntegral) >= sizeof(TestType))
            {
                // 127 is still representable in std::int8_t, so the Integral loop counters below do not overflow.
                constexpr auto bitsToWrite = 127;
                std::vector<StoredIntegral> blob(
                    llama::divCeil(std::size_t{bitsToWrite}, sizeof(StoredIntegral) * CHAR_BIT));

                // write phase: bit number i receives the value i % 2
                for(Integral i = 0; i < bitsToWrite; i++)
                    llama::mapping::internal::bitpack1(
                        blob.data(),
                        static_cast<StoredIntegral>(i),
                        static_cast<Integral>(i % 2));

                // read phase: a mismatch is reported through CHECK together with the captured index. A plain
                // assert() is deliberately not used here: it would abort the whole test run before the failure
                // could be reported and would vanish in NDEBUG builds.
                for(Integral i = 0; i < bitsToWrite; i++)
                {
                    CAPTURE(i);
                    CHECK(
                        llama::mapping::internal::bitunpack1<Integral>(blob.data(), static_cast<StoredIntegral>(i))
                        == static_cast<Integral>(i % 2));
                }
            }
        });
}