Skip to content

Commit

Permalink
add LeafRecordCoords and make forEachLeafImpl use it
Browse files Browse the repository at this point in the history
LeafRecordCoords creates a flat type list of the record coordinates of the leaves.
forEachLeafImpl uses the flat type list to save template instantiations.
  • Loading branch information
bernhardmgruber committed May 19, 2021
1 parent 75f2c0b commit 1f263ec
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 27 deletions.
48 changes: 22 additions & 26 deletions include/llama/Core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,34 +263,32 @@ namespace llama

namespace internal
{
template <typename T, std::size_t... Coords, typename Functor>
LLAMA_FN_HOST_ACC_INLINE constexpr void forEachLeafImpl(T*, RecordCoord<Coords...> coord, Functor&& functor)
template <typename RecordDim, typename RecordCoord>
struct LeafRecordCoordsImpl;

template <typename T, std::size_t... RCs>
struct LeafRecordCoordsImpl<T, RecordCoord<RCs...>>
{
functor(coord);
using type = boost::mp11::mp_list<RecordCoord<RCs...>>;
};

template <typename... Children, std::size_t... Coords, typename Functor>
LLAMA_FN_HOST_ACC_INLINE constexpr void forEachLeafImpl(
Record<Children...>*,
RecordCoord<Coords...>,
Functor&& functor)
template <typename... Fields, std::size_t... RCs>
struct LeafRecordCoordsImpl<Record<Fields...>, RecordCoord<RCs...>>
{
LLAMA_FORCE_INLINE_RECURSIVE
boost::mp11::mp_for_each<boost::mp11::mp_iota_c<sizeof...(Children)>>(
[&](auto i)
{
constexpr auto childIndex = decltype(i)::value;
using Field = boost::mp11::mp_at_c<Record<Children...>, childIndex>;

LLAMA_FORCE_INLINE_RECURSIVE
forEachLeafImpl(
static_cast<GetFieldType<Field>*>(nullptr),
RecordCoord<Coords..., childIndex>{},
std::forward<Functor>(functor));
});
}
template <std::size_t... Is>
static auto help(std::index_sequence<Is...>)
{
return boost::mp11::mp_append<
typename LeafRecordCoordsImpl<GetFieldType<Fields>, RecordCoord<RCs..., Is>>::type...>{};
}
using type = decltype(help(std::make_index_sequence<sizeof...(Fields)>{}));
};
} // namespace internal

/// Returns a flat type list containing all record coordinates to all leaves of the given record dimension.
template <typename RecordDim>
using LeafRecordCoords = typename internal::LeafRecordCoordsImpl<RecordDim, RecordCoord<>>::type;

/// Iterates over the record dimension tree and calls a functor on each element.
/// \param functor Functor to execute at each element of. Needs to have
/// `operator()` with a template parameter for the \ref RecordCoord in the
Expand All @@ -301,10 +299,8 @@ namespace llama
LLAMA_FN_HOST_ACC_INLINE constexpr void forEachLeaf(Functor&& functor, RecordCoord<Coords...> baseCoord)
{
LLAMA_FORCE_INLINE_RECURSIVE
internal::forEachLeafImpl(
static_cast<GetType<RecordDim, RecordCoord<Coords...>>*>(nullptr),
baseCoord,
std::forward<Functor>(functor));
boost::mp11::mp_for_each<LeafRecordCoords<GetType<RecordDim, RecordCoord<Coords...>>>>([&](
auto innerCoord) constexpr { functor(cat(baseCoord, innerCoord)); });
}

/// Iterates over the record dimension tree and calls a functor on each element.
Expand Down
2 changes: 1 addition & 1 deletion include/llama/RecordCoord.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ namespace llama

/// Concatenate two \ref RecordCoord instances.
template <typename RecordCoord1, typename RecordCoord2>
auto cat(RecordCoord1, RecordCoord2)
constexpr auto cat(RecordCoord1, RecordCoord2)
{
return Cat<RecordCoord1, RecordCoord2>{};
}
Expand Down

0 comments on commit 1f263ec

Please sign in to comment.