Skip to content

Commit

Permalink
Extract bitpacking routines and add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jan 12, 2023
1 parent b70a3f1 commit 4b57cd2
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 92 deletions.
2 changes: 1 addition & 1 deletion include/llama/mapping/BitPackedFloatSoA.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ namespace llama::mapping

BitPackedIntRef<
FloatBits,
StoredIntegralPointer,
std::remove_pointer_t<StoredIntegralPointer>,
decltype(integBits(std::declval<VHExp>(), std::declval<VHMan>())),
SizeType>
intref;
Expand Down
196 changes: 105 additions & 91 deletions include/llama/mapping/BitPackedIntSoA.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,105 @@ namespace llama::mapping
{
namespace internal
{
/// Builds a mask with the lowest \p bits bits set to 1.
/// Asking for at least as many bits as the type holds yields an all-ones mask,
/// sidestepping the undefined behavior of shifting by the full type width.
template<typename Integral>
LLAMA_FN_HOST_ACC_INLINE constexpr auto makeMask(Integral bits) -> Integral
{
constexpr auto typeWidth = static_cast<Integral>(sizeof(Integral) * CHAR_BIT);
if(bits >= typeWidth)
return static_cast<Integral>(~Integral{0});
return static_cast<Integral>((Integral{1} << bits) - 1u);
}

/// Loads a value of \p bitCount bits from a bit-packed buffer and widens it to \p Integral.
/// The packed value may straddle two adjacent storage words; for signed \p Integral the
/// stored sign bit (the highest of the \p bitCount bits) is extended to the full width.
/// \param ptr Start of the bit-packed buffer.
/// \param bitOffset Offset of the value's first bit, counted from the start of the buffer.
/// \param bitCount Number of bits the packed value occupies. Presumably 0 < bitCount <= width of StoredIntegral — TODO confirm with callers.
/// \return The unpacked value, converted to \p Integral.
template<typename Integral, typename StoredIntegral>
LLAMA_FN_HOST_ACC_INLINE constexpr auto bitunpack(
const StoredIntegral* ptr,
StoredIntegral bitOffset,
StoredIntegral bitCount) -> Integral
{
constexpr auto bitsPerStoredIntegral = static_cast<StoredIntegral>(sizeof(StoredIntegral) * CHAR_BIT);

// locate the storage word holding the value's first bit and the bit position within it
const auto* p = ptr + bitOffset / bitsPerStoredIntegral;
const auto innerBitOffset = bitOffset % bitsPerStoredIntegral;
// assert(p < endPtr);
auto v = p[0] >> innerBitOffset;

const auto innerBitEndOffset = innerBitOffset + bitCount;
if(innerBitEndOffset <= bitsPerStoredIntegral)
{
// value lies entirely within one storage word: just mask off the bits above it
const auto mask = makeMask(bitCount);
v &= mask;
}
else
{
// value straddles the word boundary: fetch the remaining low bits from the next word
const auto excessBits = innerBitEndOffset - bitsPerStoredIntegral;
const auto bitsLoaded = bitsPerStoredIntegral - innerBitOffset;
const auto mask = makeMask(excessBits);
// assert(p + 1 < endPtr);
v |= (p[1] & mask) << bitsLoaded;
}
if constexpr(std::is_signed_v<Integral>)
{
// perform sign extension: if the stored sign bit is set, fill all bits above
// bitCount with ones (skipped when the value already spans the full word)
if((v & (StoredIntegral{1} << (bitCount - 1))) && bitCount < bitsPerStoredIntegral)
v |= ~StoredIntegral{0} << bitCount;
}
return static_cast<Integral>(v);
}

/// Stores the lowest \p bitCount bits of \p value into a bit-packed buffer,
/// leaving all surrounding bits untouched. The packed value may straddle two
/// adjacent storage words. Inverse of \ref bitunpack.
/// \param ptr Start of the bit-packed buffer.
/// \param bitOffset Offset of the value's first bit, counted from the start of the buffer.
/// \param bitCount Number of bits to store. Presumably 0 < bitCount <= width of StoredIntegral — TODO confirm with callers.
/// \param value Value to pack; for signed types the sign is stored in the highest of the bitCount bits.
template<typename Integral, typename StoredIntegral>
LLAMA_FN_HOST_ACC_INLINE constexpr void bitpack(
StoredIntegral* ptr,
StoredIntegral bitOffset,
StoredIntegral bitCount,
Integral value)
{
constexpr auto bitsPerStoredIntegral = static_cast<StoredIntegral>(sizeof(StoredIntegral) * CHAR_BIT);

const auto unsignedValue = static_cast<StoredIntegral>(value);
const auto mask = makeMask(bitCount);
StoredIntegral valueBits;
if constexpr(!std::is_signed_v<Integral>)
valueBits = unsignedValue & mask;
else
{
// keep the low bitCount-1 magnitude bits and place the sign into bit bitCount-1
const auto magnitudeMask = makeMask(bitCount - 1);
const auto isSigned = value < 0;
valueBits = (StoredIntegral{isSigned} << (bitCount - 1)) | (unsignedValue & magnitudeMask);
}

// locate the storage word holding the value's first bit and the bit position within it
auto* p = ptr + bitOffset / bitsPerStoredIntegral;
const auto innerBitOffset = bitOffset % bitsPerStoredIntegral;

{
// read-modify-write the first word: clear the target bit range, then merge the new bits
const auto clearMask = ~(mask << innerBitOffset);
// assert(p < endPtr);
auto mem = p[0] & clearMask; // clear previous bits
mem |= valueBits << innerBitOffset; // write new bits
p[0] = mem;
}

const auto innerBitEndOffset = innerBitOffset + bitCount;
if(innerBitEndOffset > bitsPerStoredIntegral)
{
// value straddles the word boundary: write the remaining high bits into the next word
const auto excessBits = innerBitEndOffset - bitsPerStoredIntegral;
const auto bitsWritten = bitsPerStoredIntegral - innerBitOffset;
const auto clearMask = ~makeMask(excessBits);
// assert(p + 1 < endPtr);
auto mem = p[1] & clearMask; // clear previous bits
mem |= valueBits >> bitsWritten; // write new bits
p[1] = mem;
}
}

/// A proxy type representing a reference to a reduced precision integral value, stored in a buffer at a
/// specified bit offset.
/// @tparam Integral Integral data type which can be loaded and store through this reference.
/// @tparam StoredIntegralPointer Pointer to integral type used for storing the bits.
template<typename Integral, typename StoredIntegralPointer, typename VHBits, typename SizeType>
/// @tparam StoredIntegralCV Integral type used for storing the bits with CV qualifiers.
/// @tparam SizeType Type used to store sizes and offsets.
template<typename Integral, typename StoredIntegralCV, typename VHBits, typename SizeType>
struct BitPackedIntRef
: private VHBits
, ProxyRefOpMixin<BitPackedIntRef<Integral, StoredIntegralPointer, VHBits, SizeType>, Integral>
, ProxyRefOpMixin<BitPackedIntRef<Integral, StoredIntegralCV, VHBits, SizeType>, Integral>
{
private:
using StoredIntegral = std::remove_const_t<std::remove_pointer_t<StoredIntegralPointer>>;
using StoredIntegral = std::remove_cv_t<StoredIntegralCV>;

static_assert(std::is_integral_v<StoredIntegral>);
static_assert(std::is_unsigned_v<StoredIntegral>);
Expand All @@ -32,112 +120,38 @@ namespace llama::mapping
"The integral type used for the storage must be at least as big as the type of the values to "
"retrieve");

StoredIntegralPointer ptr;
StoredIntegralCV* ptr;
SizeType bitOffset;
#ifndef NDEBUG
StoredIntegralPointer endPtr;
#endif

// NOLINTNEXTLINE(bugprone-misplaced-widening-cast)
static constexpr auto bitsPerStoredIntegral = static_cast<SizeType>(sizeof(StoredIntegral) * CHAR_BIT);

LLAMA_FN_HOST_ACC_INLINE static constexpr auto makeMask(StoredIntegral bits) -> StoredIntegral
{
return bits >= sizeof(StoredIntegral) * CHAR_BIT ? ~StoredIntegral{0}
: (StoredIntegral{1} << bits) - 1u;
}

public:
using value_type = Integral;

LLAMA_FN_HOST_ACC_INLINE constexpr BitPackedIntRef(
StoredIntegralPointer ptr,
StoredIntegralCV* ptr,
SizeType bitOffset,
VHBits vhBits
#ifndef NDEBUG
,
StoredIntegralPointer endPtr
#endif
)
VHBits vhBits)
: VHBits{vhBits}
, ptr{ptr}
, bitOffset{bitOffset}

#ifndef NDEBUG
, endPtr{endPtr}
#endif
{
}

// NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions)
LLAMA_FN_HOST_ACC_INLINE constexpr operator Integral() const
{
auto* p = ptr + bitOffset / bitsPerStoredIntegral;
const auto innerBitOffset = bitOffset % bitsPerStoredIntegral;
assert(p < endPtr);
auto v = p[0] >> innerBitOffset;

const auto innerBitEndOffset = innerBitOffset + VHBits::value();
if(innerBitEndOffset <= bitsPerStoredIntegral)
{
const auto mask = makeMask(VHBits::value());
v &= mask;
}
else
{
const auto excessBits = innerBitEndOffset - bitsPerStoredIntegral;
const auto bitsLoaded = bitsPerStoredIntegral - innerBitOffset;
const auto mask = makeMask(excessBits);
assert(p + 1 < endPtr);
v |= (p[1] & mask) << bitsLoaded;
}
if constexpr(std::is_signed_v<Integral>)
{
// perform sign extension
if((v & (StoredIntegral{1} << (VHBits::value() - 1))) && VHBits::value() < bitsPerStoredIntegral)
v |= ~StoredIntegral{0} << VHBits::value();
}
return static_cast<Integral>(v);
return bitunpack<Integral>(
ptr,
static_cast<StoredIntegral>(bitOffset),
static_cast<StoredIntegral>(VHBits::value()));
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator=(Integral value) -> BitPackedIntRef&
{
// NOLINTNEXTLINE(bugprone-signed-char-misuse,cert-str34-c)
const auto unsignedValue = static_cast<StoredIntegral>(value);
const auto mask = makeMask(VHBits::value());
StoredIntegral valueBits;
if constexpr(!std::is_signed_v<Integral>)
valueBits = unsignedValue & mask;
else
{
const auto magnitudeMask = makeMask(VHBits::value() - 1);
const auto isSigned = value < 0;
valueBits = (StoredIntegral{isSigned} << (VHBits::value() - 1)) | (unsignedValue & magnitudeMask);
}

auto* p = ptr + bitOffset / bitsPerStoredIntegral;
const auto innerBitOffset = bitOffset % bitsPerStoredIntegral;

{
const auto clearMask = ~(mask << innerBitOffset);
assert(p < endPtr);
auto mem = p[0] & clearMask; // clear previous bits
mem |= valueBits << innerBitOffset; // write new bits
p[0] = mem;
}

const auto innerBitEndOffset = innerBitOffset + VHBits::value();
if(innerBitEndOffset > bitsPerStoredIntegral)
{
const auto excessBits = innerBitEndOffset - bitsPerStoredIntegral;
const auto bitsWritten = bitsPerStoredIntegral - innerBitOffset;
const auto clearMask = ~makeMask(excessBits);
assert(p + 1 < endPtr);
auto mem = p[1] & clearMask; // clear previous bits
mem |= valueBits >> bitsWritten; // write new bits
p[1] = mem;
}

bitpack<Integral>(
ptr,
static_cast<StoredIntegral>(bitOffset),
static_cast<StoredIntegral>(VHBits::value()),
value);
return *this;
}
};
Expand Down Expand Up @@ -262,7 +276,7 @@ namespace llama::mapping
using QualifiedStoredIntegral = CopyConst<Blobs, StoredIntegral>;
using DstType = GetType<TRecordDim, RecordCoord<RecordCoords...>>;
LLAMA_BEGIN_SUPPRESS_HOST_DEVICE_WARNING
return internal::BitPackedIntRef<DstType, QualifiedStoredIntegral*, VHBits, size_type>{
return internal::BitPackedIntRef<DstType, QualifiedStoredIntegral, VHBits, size_type>{
reinterpret_cast<QualifiedStoredIntegral*>(&blobs[blob][0]),
bitOffset,
static_cast<const VHBits&>(*this)
Expand Down
14 changes: 14 additions & 0 deletions tests/mapping.BitPackedIntSoA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,4 +235,18 @@ TEST_CASE("mapping.BitPackedIntSoA.ValidateBitsSmallerThanFieldType")
{
// 11 bits are larger than the uint8_t field type
CHECK_THROWS(llama::mapping::BitPackedIntSoA<llama::ArrayExtents<std::size_t, 16>, UInts, unsigned>{{}, 11});
}

// Round-trip test: pack the values 0..63 at every bit width from 6 (the minimum
// needed to hold 63) up to 32, then unpack and verify, for both 32-bit and
// 64-bit storage words.
TEMPLATE_TEST_CASE("mapping.BitPackedIntSoA.bitpack.64", "", std::uint32_t, std::uint64_t)
{
using StoredIntegral = TestType;
// Worst case: 64 values * 32 bits = 2048 bits of storage. The previous
// expression sizeof(std::uint32_t) * 64 used a byte size as an element count,
// over-allocating 4-8x. The vector value-initializes its words to zero, which
// keeps bits never written by bitpack in a defined state.
std::vector<StoredIntegral> blob(64 * 32 / (sizeof(StoredIntegral) * 8));
for(StoredIntegral bitCount = 6; bitCount <= 32; bitCount++)
{
for(std::uint32_t i = 0; i < 64; i++)
llama::mapping::internal::bitpack<std::uint32_t>(blob.data(), i * bitCount, bitCount, i);

for(std::uint32_t i = 0; i < 64; i++)
CHECK(llama::mapping::internal::bitunpack<std::uint32_t>(blob.data(), i * bitCount, bitCount) == i);
}
}

0 comments on commit 4b57cd2

Please sign in to comment.