Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce DiscreteElement class #53

Merged
merged 17 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions include/ddc/chunk.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,58 +161,52 @@ class Chunk<ElementType, DiscreteDomain<DDims...>, Allocator>
* @param mcoords 1D discrete coordinates
* @return const-reference to this element
*/
// Warning: Do not use DiscreteCoordinate because of template deduction issue with clang 12
template <class... ODDims>
element_type const& operator()(
detail::TaggedVector<DiscreteCoordElement, ODDims> const&... mcoords) const noexcept
element_type const& operator()(DiscreteCoordinate<ODDims> const&... mcoords) const noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((mcoords >= front<ODDims>(this->m_domain)) && ...));
assert(((mcoords <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(take<DDims>(mcoords...)...);
return this->m_internal_mdspan(take<DDims>(mcoords...).uid()...);
}

/** Element access using a list of DiscreteCoordinate
* @param mcoords 1D discrete coordinates
* @return reference to this element
*/
// Warning: Do not use DiscreteCoordinate because of template deduction issue with clang 12
template <class... ODDims>
element_type& operator()(
detail::TaggedVector<DiscreteCoordElement, ODDims> const&... mcoords) noexcept
element_type& operator()(DiscreteCoordinate<ODDims> const&... mcoords) noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((mcoords >= front<ODDims>(this->m_domain)) && ...));
assert(((mcoords <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(take<DDims>(mcoords...)...);
return this->m_internal_mdspan(take<DDims>(mcoords...).uid()...);
}

/** Element access using a multi-dimensional DiscreteCoordinate
* @param mcoord discrete coordinates
* @return const-reference to this element
*/
template <class... ODDims, class = std::enable_if_t<sizeof...(ODDims) != 1>>
element_type const& operator()(
detail::TaggedVector<DiscreteCoordElement, ODDims...> const& mcoord) const noexcept
element_type const& operator()(DiscreteCoordinate<ODDims...> const& mcoord) const noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((get<ODDims>(mcoord) >= front<ODDims>(this->m_domain)) && ...));
assert(((get<ODDims>(mcoord) <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(get<DDims>(mcoord)...);
assert(((select<ODDims>(mcoord) >= front<ODDims>(this->m_domain)) && ...));
assert(((select<ODDims>(mcoord) <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(uid<DDims>(mcoord)...);
}

/** Element access using a multi-dimensional DiscreteCoordinate
* @param mcoord discrete coordinates
* @return reference to this element
*/
template <class... ODDims, class = std::enable_if_t<sizeof...(ODDims) != 1>>
element_type& operator()(
detail::TaggedVector<DiscreteCoordElement, ODDims...> const& mcoord) noexcept
element_type& operator()(DiscreteCoordinate<ODDims...> const& mcoord) noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((get<ODDims>(mcoord) >= front<ODDims>(this->m_domain)) && ...));
assert(((get<ODDims>(mcoord) <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(get<DDims>(mcoord)...);
assert(((select<ODDims>(mcoord) >= front<ODDims>(this->m_domain)) && ...));
assert(((select<ODDims>(mcoord) <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(uid<DDims>(mcoord)...);
}

/** Access to the underlying allocation pointer
Expand Down
18 changes: 9 additions & 9 deletions include/ddc/chunk_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,18 +161,18 @@ class ChunkCommon<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy>
return m_internal_mdspan.accessor();
}

constexpr mcoord_type extents() const noexcept
constexpr DiscreteVector<DDims...> extents() const noexcept
{
return mcoord_type(
return DiscreteVector<DDims...>(
(m_internal_mdspan.extent(type_seq_rank_v<DDims, detail::TypeSeq<DDims...>>)
- front<DDims>(m_domain))...);
- front<DDims>(m_domain).uid())...);
}

template <class QueryDDim>
constexpr size_type extent() const noexcept
{
return m_internal_mdspan.extent(type_seq_rank_v<QueryDDim, detail::TypeSeq<DDims...>>)
- front<QueryDDim>(m_domain);
- front<QueryDDim>(m_domain).uid();
}

constexpr size_type size() const noexcept
Expand Down Expand Up @@ -249,16 +249,16 @@ class ChunkCommon<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy>
namespace stdex = std::experimental;
assert(ptr != nullptr);

extents_type extents_r(::extents<DDims>(domain)...);
extents_type extents_r(::extents<DDims>(domain).value()...);
mapping_type mapping_r(extents_r);

extents_type extents_s((front<DDims>(domain) + ::extents<DDims>(domain))...);
extents_type extents_s((front<DDims>(domain) + ::extents<DDims>(domain)).uid()...);
std::array<std::size_t, sizeof...(DDims)> strides_s {
mapping_r.stride(type_seq_rank_v<DDims, detail::TypeSeq<DDims...>>)...};
stdex::layout_stride::mapping<extents_type> mapping_s(extents_s, strides_s);

// Pointer offset to handle non-zero indexing
ptr -= mapping_s(front<DDims>(domain)...);
ptr -= mapping_s(front<DDims>(domain).uid()...);
m_internal_mdspan = internal_mdspan_type(ptr, mapping_s);
m_domain = domain;
}
Expand Down Expand Up @@ -290,7 +290,7 @@ class ChunkCommon<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy>
*/
constexpr ElementType* data() const
{
return &m_internal_mdspan(front<DDims>(m_domain)...);
return &m_internal_mdspan(front<DDims>(m_domain).uid()...);
}

/** Provide a modifiable view of the data
Expand All @@ -306,7 +306,7 @@ class ChunkCommon<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy>
*/
constexpr allocation_mdspan_type allocation_mdspan() const
{
extents_type extents_s(::extents<DDims>(m_domain)...);
extents_type extents_s(::extents<DDims>(m_domain).value()...);
if constexpr (std::is_same_v<LayoutStridedPolicy, std::experimental::layout_stride>) {
mapping_type map(extents_s, m_internal_mdspan.mapping().strides());
return allocation_mdspan_type(data(), map);
Expand Down
22 changes: 10 additions & 12 deletions include/ddc/chunk_span.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <experimental/mdspan>

#include "ddc/chunk_common.hpp"
#include "ddc/discrete_coordinate.hpp"
#include "ddc/discrete_domain.hpp"

template <class, class, class>
Expand Down Expand Up @@ -84,7 +85,7 @@ class ChunkSpan<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy>
auto get_slicer_for(DiscreteCoordinate<ODDims...> const& c) const
{
if constexpr (in_tags_v<QueryDDim, detail::TypeSeq<ODDims...>>) {
return get<QueryDDim>(c) - front<QueryDDim>(this->m_domain);
return (uid<QueryDDim>(c) - front<QueryDDim>(this->m_domain).uid());
} else {
return std::experimental::full_extent;
}
Expand Down Expand Up @@ -166,12 +167,12 @@ class ChunkSpan<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy>
constexpr ChunkSpan(allocation_mdspan_type allocation_mdspan, mdomain_type const& domain)
{
namespace stdex = std::experimental;
extents_type extents_s((front<DDims>(domain) + ::extents<DDims>(domain))...);
extents_type extents_s((front<DDims>(domain) + ::extents<DDims>(domain)).uid()...);
std::array<std::size_t, sizeof...(DDims)> strides_s {allocation_mdspan.mapping().stride(
type_seq_rank_v<DDims, detail::TypeSeq<DDims...>>)...};
stdex::layout_stride::mapping<extents_type> mapping_s(extents_s, strides_s);
this->m_internal_mdspan = internal_mdspan_type(
allocation_mdspan.data() - mapping_s(front<DDims>(domain)...),
allocation_mdspan.data() - mapping_s(front<DDims>(domain).uid()...),
mapping_s);
this->m_domain = domain;
}
Expand Down Expand Up @@ -214,29 +215,26 @@ class ChunkSpan<ElementType, DiscreteDomain<DDims...>, LayoutStridedPolicy>
* @param mcoords 1D discrete coordinates
* @return reference to this element
*/
// Warning: Do not use DiscreteCoordinate because of template deduction issue with clang 12
template <class... ODDims>
constexpr reference operator()(
detail::TaggedVector<DiscreteCoordElement, ODDims> const&... mcoords) const noexcept
constexpr reference operator()(DiscreteCoordinate<ODDims> const&... mcoords) const noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((mcoords >= front<ODDims>(this->m_domain)) && ...));
assert(((mcoords <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(take<DDims>(mcoords...)...);
return this->m_internal_mdspan(uid(take<DDims>(mcoords...))...);
}

/** Element access using a multi-dimensional DiscreteCoordinate
* @param mcoord discrete coordinates
* @return reference to this element
*/
template <class... ODDims, class = std::enable_if_t<sizeof...(ODDims) != 1>>
constexpr reference operator()(
detail::TaggedVector<DiscreteCoordElement, ODDims...> const& mcoord) const noexcept
constexpr reference operator()(DiscreteCoordinate<ODDims...> const& mcoord) const noexcept
{
static_assert(sizeof...(ODDims) == sizeof...(DDims), "Invalid number of dimensions");
assert(((get<ODDims>(mcoord) >= front<ODDims>(this->m_domain)) && ...));
assert(((get<ODDims>(mcoord) <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(get<DDims>(mcoord)...);
assert(((select<ODDims>(mcoord) >= front<ODDims>(this->m_domain)) && ...));
assert(((select<ODDims>(mcoord) <= back<ODDims>(this->m_domain)) && ...));
return this->m_internal_mdspan(uid<DDims>(mcoord)...);
}

/** Access to the underlying allocation pointer
Expand Down
1 change: 1 addition & 0 deletions include/ddc/ddc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// Discretizations
#include "ddc/discrete_coordinate.hpp"
#include "ddc/discrete_domain.hpp"
#include "ddc/discrete_vector.hpp"
#include "ddc/discretization.hpp"
#include "ddc/non_uniform_discretization.hpp"
#include "ddc/rectilinear_domain.hpp"
Expand Down
78 changes: 61 additions & 17 deletions include/ddc/detail/tagged_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,24 @@ inline constexpr ElementType const& get_or(
return tuple.template get_or<QueryTag>(default_value);
}

/// Unary operators: +, -

template <class ElementType, class... Tags, class OElementType, class... OTags>
constexpr inline detail::TaggedVector<ElementType, Tags...> operator+(
detail::TaggedVector<ElementType, Tags...> const& x)
{
return x;
}

template <class ElementType, class... Tags, class OElementType, class... OTags>
constexpr inline detail::TaggedVector<ElementType, Tags...> operator-(
detail::TaggedVector<ElementType, Tags...> const& x)
{
return detail::TaggedVector<ElementType, Tags...>((-get<Tags>(x))...);
}

/// Internal binary operators: +, -

template <class ElementType, class... Tags, class OElementType, class... OTags>
constexpr inline auto operator+(
detail::TaggedVector<ElementType, Tags...> const& lhs,
Expand All @@ -64,22 +82,32 @@ constexpr inline auto operator+(
return detail::TaggedVector<RElementType, Tags...>((get<Tags>(lhs) + get<Tags>(rhs))...);
}

template <class ElementType, class... Tags, class OElementType>
template <
class ElementType,
class Tag,
class OElementType,
class = std::enable_if_t<!detail::is_tagged_vector_v<OElementType>>,
class = std::enable_if_t<std::is_convertible_v<OElementType, ElementType>>>
constexpr inline auto operator+(
detail::TaggedVector<ElementType, Tags...> const& lhs,
detail::TaggedVector<ElementType, Tag> const& lhs,
OElementType const& rhs)
{
using RElementType = decltype(std::declval<ElementType>() + std::declval<OElementType>());
return detail::TaggedVector<RElementType, Tags...>((get<Tags>(lhs) + rhs)...);
return detail::TaggedVector<RElementType, Tag>(get<Tag>(lhs) + rhs);
}

template <class ElementType, class... Tags, class OElementType>
template <
class ElementType,
class Tag,
class OElementType,
class = std::enable_if_t<!detail::is_tagged_vector_v<OElementType>>,
class = std::enable_if_t<std::is_convertible_v<ElementType, OElementType>>>
constexpr inline auto operator+(
OElementType const& lhs,
detail::TaggedVector<ElementType, Tags...> const& rhs)
detail::TaggedVector<ElementType, Tag> const& rhs)
{
using RElementType = decltype(std::declval<ElementType>() + std::declval<OElementType>());
return detail::TaggedVector<RElementType, Tags...>((lhs + get<Tags>(rhs))...);
return detail::TaggedVector<RElementType, Tag>(lhs + get<Tag>(rhs));
}

template <class ElementType, class... Tags, class OElementType, class... OTags>
Expand All @@ -92,32 +120,48 @@ constexpr inline auto operator-(
return detail::TaggedVector<RElementType, Tags...>((get<Tags>(lhs) - get<Tags>(rhs))...);
}

template <class ElementType, class... Tags, class OElementType>
template <
class ElementType,
class Tag,
class OElementType,
class = std::enable_if_t<!detail::is_tagged_vector_v<OElementType>>,
class = std::enable_if_t<std::is_convertible_v<OElementType, ElementType>>>
constexpr inline auto operator-(
detail::TaggedVector<ElementType, Tags...> const& lhs,
detail::TaggedVector<ElementType, Tag> const& lhs,
OElementType const& rhs)
{
using RElementType = decltype(std::declval<ElementType>() + std::declval<OElementType>());
return detail::TaggedVector<RElementType, Tags...>((get<Tags>(lhs) - rhs)...);
return detail::TaggedVector<RElementType, Tag>(get<Tag>(lhs) - rhs);
}

template <class ElementType, class... Tags, class OElementType>
template <
class ElementType,
class Tag,
class OElementType,
class = std::enable_if_t<!detail::is_tagged_vector_v<OElementType>>,
class = std::enable_if_t<std::is_convertible_v<ElementType, OElementType>>>
constexpr inline auto operator-(
OElementType const& lhs,
detail::TaggedVector<ElementType, Tags...> const& rhs)
detail::TaggedVector<ElementType, Tag> const& rhs)
{
using RElementType = decltype(std::declval<ElementType>() + std::declval<OElementType>());
return detail::TaggedVector<RElementType, Tags...>((lhs - get<Tags>(rhs))...);
return detail::TaggedVector<RElementType, Tag>(lhs - get<Tag>(rhs));
}

template <class ElementType, class... Tags, class OElementType, class... OTags>
/// external left binary operator: *

template <
class ElementType,
class OElementType,
class... Tags,
class = std::enable_if_t<!detail::is_tagged_vector_v<OElementType>>,
class = std::enable_if_t<std::is_convertible_v<ElementType, OElementType>>>
constexpr inline auto operator*(
detail::TaggedVector<ElementType, Tags...> const& lhs,
detail::TaggedVector<OElementType, OTags...> const& rhs)
ElementType const& lhs,
detail::TaggedVector<OElementType, Tags...> const& rhs)
{
static_assert(type_seq_same_v<detail::TypeSeq<Tags...>, detail::TypeSeq<OTags...>>);
using RElementType = decltype(std::declval<ElementType>() * std::declval<OElementType>());
return detail::TaggedVector<RElementType, Tags...>((get<Tags>(lhs) * get<Tags>(rhs))...);
return detail::TaggedVector<RElementType, Tags...>((lhs * get<Tags>(rhs))...);
}

template <class... QueryTags, class ElementType, class... Tags>
Expand Down
Loading