Skip to content

Commit

Permalink
Add accessors to Views
Browse files Browse the repository at this point in the history
* Add a new template parameter to View
* Add a new DefaultAccessor, which just passes through the reference

Fixes: #523
  • Loading branch information
bernhardmgruber committed Sep 24, 2022
1 parent e1e3ebe commit b1d10bb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 32 deletions.
4 changes: 2 additions & 2 deletions include/llama/RecordRef.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ namespace llama
else
{
LLAMA_FORCE_INLINE_RECURSIVE
return this->view.accessor(arrayIndex(), AbsolutCoord{});
return this->view.access(arrayIndex(), AbsolutCoord{});
}
}

Expand All @@ -454,7 +454,7 @@ namespace llama
else
{
LLAMA_FORCE_INLINE_RECURSIVE
return this->view.accessor(arrayIndex(), AbsolutCoord{});
return this->view.access(arrayIndex(), AbsolutCoord{});
}
}

Expand Down
69 changes: 39 additions & 30 deletions include/llama/View.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,20 @@

namespace llama
{
/// Default accessor. Passes through the given reference.
struct DefaultAccessor
{
template<typename Reference>
LLAMA_FN_HOST_ACC_INLINE auto operator()(Reference&& r) const -> Reference
{
return std::forward<Reference>(r);
}
};

#ifdef __cpp_lib_concepts
template<typename TMapping, Blob BlobType>
template<typename TMapping, Blob BlobType, typename TAccessor = DefaultAccessor>
#else
template<typename TMapping, typename BlobType>
template<typename TMapping, typename BlobType, typename TAccessor = DefaultAccessor>
#endif
struct View;

Expand Down Expand Up @@ -335,13 +345,15 @@ namespace llama
/// view should be created using \ref allocView.
/// \tparam TMapping The mapping used by the view to map accesses into memory.
/// \tparam BlobType The storage type used by the view holding memory.
/// \tparam TAccessor The accessor to use when an access is made through this view.
#ifdef __cpp_lib_concepts
template<typename TMapping, Blob BlobType>
template<typename TMapping, Blob BlobType, typename TAccessor>
#else
template<typename TMapping, typename BlobType>
template<typename TMapping, typename BlobType, typename TAccessor>
#endif
struct LLAMA_DECLSPEC_EMPTY_BASES View
: private TMapping
, private TAccessor
#if CAN_USE_RANGES
, std::ranges::view_base
#endif
Expand All @@ -351,6 +363,7 @@ namespace llama
using ArrayExtents = typename Mapping::ArrayExtents;
using ArrayIndex = typename Mapping::ArrayIndex;
using RecordDim = typename Mapping::RecordDim;
using Accessor = TAccessor;
using iterator = Iterator<View>;
using const_iterator = Iterator<const View>;
using size_type = typename ArrayExtents::value_type;
Expand Down Expand Up @@ -385,6 +398,16 @@ namespace llama
return static_cast<const Mapping&>(*this);
}

LLAMA_FN_HOST_ACC_INLINE auto accessor() -> Accessor&
{
return static_cast<Accessor&>(*this);
}

LLAMA_FN_HOST_ACC_INLINE auto accessor() const -> const Accessor&
{
return static_cast<const Accessor&>(*this);
}

#if !(defined(_MSC_VER) && defined(__NVCC__))
template<typename V>
auto operator()(llama::ArrayIndex<V, ArrayIndex::rank>) const
Expand All @@ -404,7 +427,7 @@ namespace llama
else
{
LLAMA_FORCE_INLINE_RECURSIVE
return accessor(ai, RecordCoord<>{});
return access(ai, RecordCoord<>{});
}
}

Expand All @@ -418,7 +441,7 @@ namespace llama
else
{
LLAMA_FORCE_INLINE_RECURSIVE
return accessor(ai, RecordCoord<>{});
return access(ai, RecordCoord<>{});
}
}

Expand Down Expand Up @@ -515,16 +538,16 @@ namespace llama

LLAMA_SUPPRESS_HOST_DEVICE_WARNING
template<std::size_t... Coords>
LLAMA_FN_HOST_ACC_INLINE auto accessor(ArrayIndex ai, RecordCoord<Coords...> rc = {}) const -> decltype(auto)
LLAMA_FN_HOST_ACC_INLINE auto access(ArrayIndex ai, RecordCoord<Coords...> rc = {}) const -> decltype(auto)
{
return mapToMemory(mapping(), ai, rc, storageBlobs);
return accessor()(mapToMemory(mapping(), ai, rc, storageBlobs));
}

LLAMA_SUPPRESS_HOST_DEVICE_WARNING
template<std::size_t... Coords>
LLAMA_FN_HOST_ACC_INLINE auto accessor(ArrayIndex ai, RecordCoord<Coords...> rc = {}) -> decltype(auto)
LLAMA_FN_HOST_ACC_INLINE auto access(ArrayIndex ai, RecordCoord<Coords...> rc = {}) -> decltype(auto)
{
return mapToMemory(mapping(), ai, rc, storageBlobs);
return accessor()(mapToMemory(mapping(), ai, rc, storageBlobs));
}
};

Expand Down Expand Up @@ -587,18 +610,6 @@ namespace llama
{
}

template<std::size_t... Coords>
LLAMA_FN_HOST_ACC_INLINE auto accessor(ArrayIndex ai) const -> const auto&
{
return parentView.template accessor<Coords...>(ArrayIndex{ai + offset});
}

template<std::size_t... Coords>
LLAMA_FN_HOST_ACC_INLINE auto accessor(ArrayIndex ai) -> auto&
{
return parentView.template accessor<Coords...>(ArrayIndex{ai + offset});
}

/// Same as \ref View::operator()(ArrayIndex), but shifted by the offset of this \ref VirtualView.
LLAMA_FN_HOST_ACC_INLINE auto operator()(ArrayIndex ai) const -> decltype(auto)
{
Expand Down Expand Up @@ -641,18 +652,16 @@ namespace llama
ArrayIndex{ArrayIndex{static_cast<typename ArrayIndex::value_type>(indices)...} + offset});
}

template<std::size_t... Coord>
LLAMA_FN_HOST_ACC_INLINE auto operator()(RecordCoord<Coord...> = {}) const -> decltype(auto)
template<std::size_t... Coords>
LLAMA_FN_HOST_ACC_INLINE auto operator()(RecordCoord<Coords...> rc = {}) const -> decltype(auto)
{
LLAMA_FORCE_INLINE_RECURSIVE
return accessor<Coord...>(ArrayIndex{});
return parentView(ArrayIndex{} + offset, rc);
}

template<std::size_t... Coord>
LLAMA_FN_HOST_ACC_INLINE auto operator()(RecordCoord<Coord...> = {}) -> decltype(auto)
template<std::size_t... Coords>
LLAMA_FN_HOST_ACC_INLINE auto operator()(RecordCoord<Coords...> rc = {}) -> decltype(auto)
{
LLAMA_FORCE_INLINE_RECURSIVE
return accessor<Coord...>(ArrayIndex{});
return parentView(ArrayIndex{} + offset, rc);
}

StoredParentView parentView;
Expand Down

0 comments on commit b1d10bb

Please sign in to comment.