Skip to content

Commit

Permalink
add ProxyRefOpMixin CRTP mixing for proxy references
Browse files Browse the repository at this point in the history
* supplies compound assignment and increment/decrement operators
* mixin ProxyRefOpMixin into proxy references for ChangeTypeReference, BitPackedFloatRef, BitPackedIntRef and Bytesplit::Reference
* add unit tests for ProxyRefOpMixin and proxy references of mappings

Fixes alpaka-group#429
  • Loading branch information
bernhardmgruber committed Dec 10, 2021
1 parent 727f363 commit ec4fe12
Show file tree
Hide file tree
Showing 7 changed files with 473 additions and 18 deletions.
144 changes: 144 additions & 0 deletions include/llama/ProxyRefOpMixin.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// SPDX-License-Identifier: GPL-3.0-or-later

#pragma once

#include "Concepts.hpp"
#include "macros.hpp"

namespace llama
{
/// CRTP mixin for proxy reference types to support all compound assignment and increment/decrement operators.
template<typename Derived, typename ValueType>
struct ProxyRefOpMixin
{
private:
LLAMA_FN_HOST_ACC_INLINE constexpr auto derived() -> Derived&
{
return static_cast<Derived&>(*this);
}

// in principle, load() could be const, but we use it only from non-const functions
LLAMA_FN_HOST_ACC_INLINE constexpr auto load() -> ValueType
{
return static_cast<ValueType>(derived());
}

LLAMA_FN_HOST_ACC_INLINE constexpr void store(ValueType t)
{
derived() = std::move(t);
}

public:
LLAMA_FN_HOST_ACC_INLINE constexpr auto operator+=(const ValueType& rhs) -> Derived&
{
ValueType lhs = load();
lhs += rhs;
store(lhs);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator-=(const ValueType& rhs) -> Derived&
{
ValueType lhs = load();
lhs -= rhs;
store(lhs);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator*=(const ValueType& rhs) -> Derived&
{
ValueType lhs = load();
lhs *= rhs;
store(lhs);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator/=(const ValueType& rhs) -> Derived&
{
ValueType lhs = load();
lhs /= rhs;
store(lhs);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator%=(const ValueType& rhs) -> Derived&
{
ValueType lhs = load();
lhs %= rhs;
store(lhs);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator<<=(const ValueType& rhs) -> Derived&
{
ValueType lhs = load();
lhs <<= rhs;
store(lhs);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator>>=(const ValueType& rhs) -> Derived&
{
ValueType lhs = load();
lhs >>= rhs;
store(lhs);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator&=(const ValueType& rhs) -> Derived&
{
ValueType lhs = load();
lhs &= rhs;
store(lhs);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator|=(const ValueType& rhs) -> Derived&
{
ValueType lhs = load();
lhs |= rhs;
store(lhs);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator^=(const ValueType& rhs) -> Derived&
{
ValueType lhs = load();
lhs ^= rhs;
store(lhs);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator++() -> Derived&
{
ValueType v = load();
++v;
store(v);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator++(int) -> ValueType
{
ValueType v = load();
ValueType old = v++;
store(v);
return old;
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator--() -> Derived&
{
ValueType v = load();
--v;
store(v);
return derived();
}

LLAMA_FN_HOST_ACC_INLINE constexpr auto operator--(int) -> ValueType
{
ValueType v = load();
ValueType old = v--;
store(v);
return old;
}
};
} // namespace llama
1 change: 1 addition & 0 deletions include/llama/llama.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "Meta.hpp"
#include "Vector.hpp"
#include "View.hpp"
#include "ProxyRefOpMixin.hpp"
#include "VirtualRecord.hpp"
#include "macros.hpp"
#include "mapping/AoS.hpp"
Expand Down
13 changes: 8 additions & 5 deletions include/llama/mapping/BitPackedFloatSoA.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#pragma once

#include "../ProxyRefOpMixin.hpp"
#include "BitPackedIntSoA.hpp"

#include <climits>
Expand Down Expand Up @@ -82,8 +83,9 @@ namespace llama::mapping
/// @tparam Integral Integral data type which can be loaded and store through this reference.
/// @tparam StoredIntegralPointer Pointer to integral type used for storing the bits.
template<typename Float, typename StoredIntegralPointer>
struct BitPackedFloatRef
struct BitPackedFloatRef : ProxyRefOpMixin<BitPackedFloatRef<Float, StoredIntegralPointer>, Float>
{
private:
static_assert(
std::is_same_v<Float, float> || std::is_same_v<Float, double>,
"Types other than float or double are not implemented yet");
Expand All @@ -93,13 +95,14 @@ namespace llama::mapping

using FloatBits = std::conditional_t<std::is_same_v<Float, float>, std::uint32_t, std::uint64_t>;

private:
internal::BitPackedIntRef<FloatBits, StoredIntegralPointer> intref;
unsigned exponentBits = 0;
unsigned mantissaBits = 0;

public:
BitPackedFloatRef(
using value_type = Float;

LLAMA_FN_HOST_ACC_INLINE constexpr BitPackedFloatRef(
StoredIntegralPointer p,
std::size_t bitOffset,
unsigned exponentBits,
Expand All @@ -120,7 +123,7 @@ namespace llama::mapping
}

// NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions)
operator Float() const
LLAMA_FN_HOST_ACC_INLINE constexpr operator Float() const
{
using Bits = internal::FloatBitTraits<Float>;
const FloatBits packedFloat = intref;
Expand All @@ -131,7 +134,7 @@ namespace llama::mapping
return f;
}

auto operator=(Float f) -> BitPackedFloatRef&
LLAMA_FN_HOST_ACC_INLINE constexpr auto operator=(Float f) -> BitPackedFloatRef&
{
using Bits = internal::FloatBitTraits<Float>;
FloatBits unpackedFloat = 0;
Expand Down
36 changes: 33 additions & 3 deletions include/llama/mapping/BitPackedIntSoA.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#pragma once

#include "../ProxyRefOpMixin.hpp"

#include <climits>
#include <type_traits>

Expand All @@ -14,8 +16,9 @@ namespace llama::mapping
/// @tparam Integral Integral data type which can be loaded and store through this reference.
/// @tparam StoredIntegralPointer Pointer to integral type used for storing the bits.
template<typename Integral, typename StoredIntegralPointer>
struct BitPackedIntRef
struct BitPackedIntRef : ProxyRefOpMixin<BitPackedIntRef<Integral, StoredIntegralPointer>, Integral>
{
private:
using StoredIntegral = std::remove_const_t<std::remove_pointer_t<StoredIntegralPointer>>;

static_assert(std::is_integral_v<Integral>);
Expand All @@ -35,8 +38,35 @@ namespace llama::mapping

static constexpr auto bitsPerStoredIntegral = sizeof(StoredIntegral) * CHAR_BIT;

public:
using value_type = Integral;

LLAMA_FN_HOST_ACC_INLINE constexpr BitPackedIntRef(
StoredIntegralPointer ptr,
std::size_t bitOffset,
unsigned bits
#ifndef NDEBUG
,
StoredIntegralPointer endPtr
#endif
)
: ptr{ptr}
, bitOffset{bitOffset}
, bits
{
bits
}
#ifndef NDEBUG
endPtr
{
endPtr
}
#endif
{
}

// NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions)
operator Integral() const
LLAMA_FN_HOST_ACC_INLINE constexpr operator Integral() const
{
auto* p = ptr + bitOffset / bitsPerStoredIntegral;
const auto innerBitOffset = bitOffset % bitsPerStoredIntegral;
Expand All @@ -63,7 +93,7 @@ namespace llama::mapping
return static_cast<Integral>(v);
}

auto operator=(Integral value) -> BitPackedIntRef&
LLAMA_FN_HOST_ACC_INLINE constexpr auto operator=(Integral value) -> BitPackedIntRef&
{
const auto unsignedValue = static_cast<StoredIntegral>(value);
const auto mask = (StoredIntegral{1} << bits) - 1u;
Expand Down
25 changes: 18 additions & 7 deletions include/llama/mapping/Bytesplit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#pragma once

#include "../ProxyRefOpMixin.hpp"
#include "Common.hpp"

namespace llama::mapping
Expand All @@ -15,6 +16,8 @@ namespace llama::mapping
using SplitBytes = TransformLeaves<RecordDim, ReplaceByByteArray>;
} // namespace internal

/// Meta mapping splitting each field in the record dimension into an array of bytes and mapping the resulting
/// record dimension using a further mapping.
template<typename TArrayExtents, typename TRecordDim, template<typename, typename> typename InnerMapping>
struct Bytesplit : private InnerMapping<TArrayExtents, internal::SplitBytes<TRecordDim>>
{
Expand All @@ -41,20 +44,28 @@ namespace llama::mapping
}

template<typename QualifiedBase, typename RC, typename BlobArray>
struct Reference
struct Reference : ProxyRefOpMixin<Reference<QualifiedBase, RC, BlobArray>, GetType<TRecordDim, RC>>
{
QualifiedBase& innerMapping;
ArrayIndex ai;
BlobArray& blobs;

using DstType = GetType<TRecordDim, RC>;
public:
using value_type = GetType<TRecordDim, RC>;

Reference(QualifiedBase& innerMapping, ArrayIndex ai, BlobArray& blobs)
: innerMapping(innerMapping)
, ai(ai)
, blobs(blobs)
{
}

// NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions)
operator DstType() const
operator value_type() const
{
DstType v;
value_type v;
auto* p = reinterpret_cast<std::byte*>(&v);
boost::mp11::mp_for_each<boost::mp11::mp_iota_c<sizeof(DstType)>>(
boost::mp11::mp_for_each<boost::mp11::mp_iota_c<sizeof(value_type)>>(
[&](auto ic)
{
constexpr auto i = decltype(ic)::value;
Expand All @@ -64,10 +75,10 @@ namespace llama::mapping
return v;
}

auto operator=(DstType v) -> Reference&
auto operator=(value_type v) -> Reference&
{
auto* p = reinterpret_cast<std::byte*>(&v);
boost::mp11::mp_for_each<boost::mp11::mp_iota_c<sizeof(DstType)>>(
boost::mp11::mp_for_each<boost::mp11::mp_iota_c<sizeof(value_type)>>(
[&](auto ic)
{
constexpr auto i = decltype(ic)::value;
Expand Down
15 changes: 12 additions & 3 deletions include/llama/mapping/ChangeType.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#pragma once

#include "../ProxyRefOpMixin.hpp"
#include "Common.hpp"

namespace llama::mapping
Expand All @@ -22,17 +23,25 @@ namespace llama::mapping
using ReplaceType = TransformLeaves<RecordDim, MakeReplacer<ReplacementMap>::template type>;

template<typename UserT, typename StoredT>
struct ChangeTypeReference
struct ChangeTypeReference : ProxyRefOpMixin<ChangeTypeReference<UserT, StoredT>, UserT>
{
private:
StoredT& storageRef;

public:
using value_type = UserT;

LLAMA_FN_HOST_ACC_INLINE constexpr ChangeTypeReference(StoredT& storageRef) : storageRef{storageRef}
{
}

// NOLINTNEXTLINE(google-explicit-constructor,hicpp-explicit-conversions)
operator UserT() const
LLAMA_FN_HOST_ACC_INLINE constexpr operator UserT() const
{
return storageRef;
}

auto operator=(UserT v) -> ChangeTypeReference&
LLAMA_FN_HOST_ACC_INLINE constexpr auto operator=(UserT v) -> ChangeTypeReference&
{
storageRef = v;
return *this;
Expand Down
Loading

0 comments on commit ec4fe12

Please sign in to comment.