Skip to content

Commit

Permalink
Second implementation of the combination of DiscreteElements (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
tpadioleau authored Nov 2, 2023
1 parent c4e1f67 commit 870630b
Show file tree
Hide file tree
Showing 13 changed files with 216 additions and 136 deletions.
8 changes: 4 additions & 4 deletions examples/heat_equation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ int main(int argc, char** argv)
//! [X-domains]
// our zone at the start of the domain that will be mirrored to the
// ghost
ddc::DiscreteDomain const
ddc::DiscreteDomain<DDimX> const
x_domain_begin(x_domain.front(), x_post_ghost.extents());
// our zone at the end of the domain that will be mirrored to the
// ghost
ddc::DiscreteDomain const x_domain_end(
ddc::DiscreteDomain<DDimX> const x_domain_end(
x_domain.back() - x_pre_ghost.extents() + 1,
x_pre_ghost.extents());
//! [X-domains]
Expand All @@ -146,11 +146,11 @@ int main(int argc, char** argv)

// our zone at the start of the domain that will be mirrored to the
// ghost
ddc::DiscreteDomain const
ddc::DiscreteDomain<DDimY> const
y_domain_begin(y_domain.front(), y_post_ghost.extents());
// our zone at the end of the domain that will be mirrored to the
// ghost
ddc::DiscreteDomain const y_domain_end(
ddc::DiscreteDomain<DDimY> const y_domain_end(
y_domain.back() - y_pre_ghost.extents() + 1,
y_pre_ghost.extents());
//! [Y-domains]
Expand Down
58 changes: 19 additions & 39 deletions include/ddc/chunk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,55 +182,35 @@ class Chunk<ElementType, DiscreteDomain<DDims...>, Allocator>
}

/** Element access using a list of DiscreteElement
* @param delems 1D discrete coordinates
* @return const-reference to this element
*/
template <class... ODDims>
element_type const& operator()(DiscreteElement<ODDims> const&... delems) const noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((delems >= front<ODDims>(this->m_domain)) && ...));
assert(((delems <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(take<DDims>(delems...).uid()...);
}

/** Element access using a list of DiscreteElement
* @param delems 1D discrete coordinates
* @return reference to this element
*/
template <class... ODDims>
element_type& operator()(DiscreteElement<ODDims> const&... delems) noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((delems >= front<ODDims>(this->m_domain)) && ...));
assert(((delems <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(take<DDims>(delems...).uid()...);
}

/** Element access using a multi-dimensional DiscreteElement
* @param delems discrete coordinates
* @return const-reference to this element
*/
template <class... ODDims, class = std::enable_if_t<sizeof...(ODDims) != 1>>
element_type const& operator()(DiscreteElement<ODDims...> const& delems) const noexcept
template <class... DElems>
element_type const& operator()(DElems const&... delems) const noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((select<ODDims>(delems) >= front<ODDims>(this->m_domain)) && ...));
assert(((select<ODDims>(delems) <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(uid<DDims>(delems)...);
static_assert(
sizeof...(DDims) == (0 + ... + DElems::size()),
"Invalid number of dimensions");
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
assert(((select<DDims>(take<DDims>(delems...)) >= front<DDims>(this->m_domain)) && ...));
assert(((select<DDims>(take<DDims>(delems...)) <= back<DDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(uid<DDims>(take<DDims>(delems...))...);
}

/** Element access using a multi-dimensional DiscreteElement
/** Element access using a list of DiscreteElement
* @param delems discrete coordinates
* @return reference to this element
*/
template <class... ODDims, class = std::enable_if_t<sizeof...(ODDims) != 1>>
element_type& operator()(DiscreteElement<ODDims...> const& delems) noexcept
template <class... DElems>
element_type& operator()(DElems const&... delems) noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((select<ODDims>(delems) >= front<ODDims>(this->m_domain)) && ...));
assert(((select<ODDims>(delems) <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(uid<DDims>(delems)...);
static_assert(
sizeof...(DDims) == (0 + ... + DElems::size()),
"Invalid number of dimensions");
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
assert(((select<DDims>(take<DDims>(delems...)) >= front<DDims>(this->m_domain)) && ...));
assert(((select<DDims>(take<DDims>(delems...)) <= back<DDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(uid<DDims>(take<DDims>(delems...))...);
}

/** Returns the label of the Chunk
Expand Down
30 changes: 9 additions & 21 deletions include/ddc/chunk_span.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,31 +252,19 @@ class ChunkSpan<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy, Memo
}

/** Element access using a list of DiscreteElement
* @param delems 1D discrete elements
* @return reference to this element
*/
template <class... ODDims>
KOKKOS_FUNCTION constexpr reference operator()(
DiscreteElement<ODDims> const&... delems) const noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((delems >= front<ODDims>(this->m_domain)) && ...));
assert(((delems <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(uid(take<DDims>(delems...))...);
}

/** Element access using a multi-dimensional DiscreteElement
* @param delems discrete elements
* @return reference to this element
*/
template <class... ODDims, class = std::enable_if_t<sizeof...(ODDims) != 1>>
KOKKOS_FUNCTION constexpr reference operator()(
DiscreteElement<ODDims...> const& delems) const noexcept
template <class... DElems>
KOKKOS_FUNCTION constexpr reference operator()(DElems const&... delems) const noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((select<ODDims>(delems) >= front<ODDims>(this->m_domain)) && ...));
assert(((select<ODDims>(delems) <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(uid<DDims>(delems)...);
static_assert(
sizeof...(DDims) == (0 + ... + DElems::size()),
"Invalid number of dimensions");
static_assert((is_discrete_element_v<DElems> && ...), "Expected DiscreteElements");
assert(((select<DDims>(take<DDims>(delems...)) >= front<DDims>(this->m_domain)) && ...));
assert(((select<DDims>(take<DDims>(delems...)) <= back<DDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(uid<DDims>(take<DDims>(delems...))...);
}

/** Access to the underlying allocation pointer
Expand Down
47 changes: 27 additions & 20 deletions include/ddc/detail/tagged_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ struct IsTaggedVector<TaggedVector<ElementType, Tags...>> : std::true_type
template <class T>
inline constexpr bool is_tagged_vector_v = IsTaggedVector<T>::value;

template <class ElementType, class... Tags>
struct ToTypeSeq<TaggedVector<ElementType, Tags...>>
{
using type = TypeSeq<Tags...>;
};

} // namespace detail

Expand Down Expand Up @@ -186,20 +191,30 @@ KOKKOS_FUNCTION constexpr detail::TaggedVector<ElementType, QueryTags...> select

namespace detail {

template <class QueryTag, class ElementType, class HeadTag, class... TailTags>
KOKKOS_FUNCTION constexpr detail::TaggedVector<ElementType, QueryTag> const& take(
detail::TaggedVector<ElementType, HeadTag> const& head,
detail::TaggedVector<ElementType, TailTags> const&... tags)
/// Returns a reference towards the DiscreteElement that contains the QueryTag
template <
class QueryTag,
class HeadTaggedVector,
class... TailTaggedVectors,
std::enable_if_t<
is_tagged_vector_v<
HeadTaggedVector> && (is_tagged_vector_v<TailTaggedVectors> && ...),
int> = 1>
KOKKOS_FUNCTION constexpr auto const& take(
HeadTaggedVector const& head,
TailTaggedVectors const&... tail)
{
DDC_IF_NVCC_THEN_PUSH_AND_SUPPRESS(implicit_return_from_non_void_function)
if constexpr (std::is_same_v<QueryTag, HeadTag>) {
if constexpr (type_seq_contains_v<detail::TypeSeq<QueryTag>, to_type_seq_t<HeadTaggedVector>>) {
static_assert(
!type_seq_contains_v<detail::TypeSeq<QueryTag>, detail::TypeSeq<TailTags...>>,
(!type_seq_contains_v<
detail::TypeSeq<QueryTag>,
to_type_seq_t<TailTaggedVectors>> && ...),
"ERROR: tag redundant");
return head;
} else {
static_assert(sizeof...(TailTags) > 0, "ERROR: tag not found");
return take<QueryTag>(tags...);
static_assert(sizeof...(TailTaggedVectors) > 0, "ERROR: tag not found");
return take<QueryTag>(tail...);
}
DDC_IF_NVCC_THEN_POP
}
Expand Down Expand Up @@ -250,24 +265,16 @@ class TaggedVector : public ConversionOperators<TaggedVector<ElementType, Tags..

KOKKOS_DEFAULTED_FUNCTION constexpr TaggedVector(TaggedVector&&) = default;

template <class... OTags>
KOKKOS_FUNCTION constexpr TaggedVector(
TaggedVector<ElementType, OTags> const&... other) noexcept
: m_values {take<Tags>(other...).value()...}
{
}

template <class OElementType, class... OTags>
explicit KOKKOS_FUNCTION constexpr TaggedVector(
TaggedVector<OElementType, OTags...> const& other) noexcept
: m_values {(static_cast<ElementType>(other.template get<Tags>()))...}
template <class... TVectors, class = std::enable_if_t<(is_tagged_vector_v<TVectors> && ...)>>
explicit KOKKOS_FUNCTION constexpr TaggedVector(TVectors const&... delems) noexcept
: m_values {static_cast<ElementType>(take<Tags>(delems...).template get<Tags>())...}
{
}

template <
class... Params,
class = std::enable_if_t<(std::is_convertible_v<Params, ElementType> && ...)>,
class = std::enable_if_t<(!is_tagged_vector_v<Params> && ...)>,
class = std::enable_if_t<(std::is_convertible_v<Params, ElementType> && ...)>,
class = std::enable_if_t<sizeof...(Params) == sizeof...(Tags)>>
explicit KOKKOS_FUNCTION constexpr TaggedVector(Params const&... params) noexcept
: m_values {static_cast<ElementType>(params)...}
Expand Down
7 changes: 7 additions & 0 deletions include/ddc/detail/type_seq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ struct TypeSeqReplace<
{
};

template <class T>
struct ToTypeSeq;

} // namespace detail

template <class QueryTag, class TypeSeq>
Expand Down Expand Up @@ -192,4 +195,8 @@ using type_seq_merge_t = typename detail::TypeSeqMerge<TagSeqA, TagSeqB, TagSeqA
template <class TagSeqA, class TagSeqB, class TagSeqC>
using type_seq_replace_t =
typename detail::TypeSeqReplace<TagSeqA, TagSeqB, TagSeqC, detail::TypeSeq<>>::type;

template <class T>
using to_type_seq_t = typename detail::ToTypeSeq<T>::type;

} // namespace ddc
39 changes: 27 additions & 12 deletions include/ddc/discrete_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@ struct DiscreteDomainIterator;
template <class... DDims>
class DiscreteDomain;

template <class T>
struct IsDiscreteDomain : std::false_type
{
};

template <class... Tags>
struct IsDiscreteDomain<DiscreteDomain<Tags...>> : std::true_type
{
};

template <class T>
inline constexpr bool is_discrete_domain_v = IsDiscreteDomain<T>::value;


namespace detail {

template <class... Tags>
struct ToTypeSeq<DiscreteDomain<Tags...>>
{
using type = TypeSeq<Tags...>;
};

} // namespace detail

template <class... DDims>
class DiscreteDomain
{
Expand All @@ -43,18 +67,9 @@ class DiscreteDomain

KOKKOS_DEFAULTED_FUNCTION DiscreteDomain() = default;

/// Construct a DiscreteDomain from a reordered copy of `domain`
template <class... ODDims>
explicit KOKKOS_FUNCTION constexpr DiscreteDomain(DiscreteDomain<ODDims...> const& domain)
: m_element_begin(domain.front())
, m_element_end(domain.front() + domain.extents())
{
}

// Use SFINAE to disambiguate with the copy constructor.
// Note that SFINAE may be redundant because a template constructor should not be selected as a copy constructor.
template <std::size_t N = sizeof...(DDims), class = std::enable_if_t<(N != 1)>>
explicit KOKKOS_FUNCTION constexpr DiscreteDomain(DiscreteDomain<DDims> const&... domains)
/// Construct a DiscreteDomain by copies and merge of domains
template <class... DDoms, class = std::enable_if_t<(is_discrete_domain_v<DDoms> && ...)>>
explicit KOKKOS_FUNCTION constexpr DiscreteDomain(DDoms const&... domains)
: m_element_begin(domains.front()...)
, m_element_end((domains.front() + domains.extents())...)
{
Expand Down
49 changes: 29 additions & 20 deletions include/ddc/discrete_element.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
#include <array>
#include <cstddef>
#include <ostream>
#include <type_traits>
#include <utility>

#include "ddc/coordinate.hpp"
#include "ddc/detail/macros.hpp"
#include "ddc/detail/type_seq.hpp"
#include "ddc/discrete_vector.hpp"

Expand All @@ -30,6 +32,16 @@ template <class T>
inline constexpr bool is_discrete_element_v = IsDiscreteElement<T>::value;


namespace detail {

template <class... Tags>
struct ToTypeSeq<DiscreteElement<Tags...>>
{
using type = TypeSeq<Tags...>;
};

} // namespace detail

/** A DiscreteCoordElement is a scalar that identifies an element of the discrete dimension
*/
using DiscreteElementType = std::size_t;
Expand Down Expand Up @@ -81,20 +93,25 @@ KOKKOS_FUNCTION constexpr DiscreteElement<QueryTags...> select(
return DiscreteElement<QueryTags...>(std::move(arr));
}

template <class QueryTag, class HeadTag, class... TailTags>
KOKKOS_FUNCTION constexpr DiscreteElement<QueryTag> const& take(
DiscreteElement<HeadTag> const& head,
DiscreteElement<TailTags> const&... tags)
/// Returns a reference towards the DiscreteElement that contains the QueryTag
template <
class QueryTag,
class HeadDElem,
class... TailDElems,
std::enable_if_t<
is_discrete_element_v<HeadDElem> && (is_discrete_element_v<TailDElems> && ...),
int> = 1>
KOKKOS_FUNCTION constexpr auto const& take(HeadDElem const& head, TailDElems const&... tail)
{
DDC_IF_NVCC_THEN_PUSH_AND_SUPPRESS(implicit_return_from_non_void_function)
if constexpr (std::is_same_v<QueryTag, HeadTag>) {
if constexpr (type_seq_contains_v<detail::TypeSeq<QueryTag>, to_type_seq_t<HeadDElem>>) {
static_assert(
!type_seq_contains_v<detail::TypeSeq<QueryTag>, detail::TypeSeq<TailTags...>>,
(!type_seq_contains_v<detail::TypeSeq<QueryTag>, to_type_seq_t<TailDElems>> && ...),
"ERROR: tag redundant");
return head;
} else {
static_assert(sizeof...(TailTags) > 0, "ERROR: tag not found");
return take<QueryTag>(tags...);
static_assert(sizeof...(TailDElems) > 0, "ERROR: tag not found");
return take<QueryTag>(tail...);
}
DDC_IF_NVCC_THEN_POP
}
Expand Down Expand Up @@ -152,24 +169,16 @@ class DiscreteElement

KOKKOS_DEFAULTED_FUNCTION constexpr DiscreteElement(DiscreteElement&&) = default;

template <class... OTags>
explicit KOKKOS_FUNCTION constexpr DiscreteElement(
DiscreteElement<OTags> const&... other) noexcept
: m_values {take<Tags>(other...).uid()...}
{
}

template <class... OTags>
explicit KOKKOS_FUNCTION constexpr DiscreteElement(
DiscreteElement<OTags...> const& other) noexcept
: m_values {other.template uid<Tags>()...}
template <class... DElems, class = std::enable_if_t<(is_discrete_element_v<DElems> && ...)>>
explicit KOKKOS_FUNCTION constexpr DiscreteElement(DElems const&... delems) noexcept
: m_values {take<Tags>(delems...).template uid<Tags>()...}
{
}

template <
class... Params,
class = std::enable_if_t<(std::is_integral_v<Params> && ...)>,
class = std::enable_if_t<(!is_discrete_element_v<Params> && ...)>,
class = std::enable_if_t<(std::is_integral_v<Params> && ...)>,
class = std::enable_if_t<sizeof...(Params) == sizeof...(Tags)>>
explicit KOKKOS_FUNCTION constexpr DiscreteElement(Params const&... params) noexcept
: m_values {static_cast<value_type>(params)...}
Expand Down
Loading

0 comments on commit 870630b

Please sign in to comment.