Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update strided_index_range to strided_slice #231

Merged
merged 4 commits into from
Jan 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/experimental/__p0009_bits/layout_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ struct layout_stride {

template<class SizeType, ::std::size_t ... Ep, ::std::size_t ... Idx>
_MDSPAN_HOST_DEVICE
constexpr index_type __get_size(extents<SizeType, Ep...>,integer_sequence<::std::size_t, Idx...>) const {
constexpr index_type __get_size(::std::experimental::extents<SizeType, Ep...>,integer_sequence<::std::size_t, Idx...>) const {
return _MDSPAN_FOLD_TIMES_RIGHT( static_cast<index_type>(extents().extent(Idx)), 1 );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,32 @@
//
//@HEADER

#include <type_traits>

namespace std {
namespace experimental {

namespace {
template<class T>
struct __mdspan_is_integral_constant: std::false_type {};

template<class T, T val>
struct __mdspan_is_integral_constant<integral_constant<T,val>>: std::true_type {};
}
// Slice Specifier allowing for strides and compile time extent
template <class OffsetType, class ExtentType, class StrideType>
struct strided_index_range {
struct strided_slice {
using offset_type = OffsetType;
using extent_type = ExtentType;
using stride_type = StrideType;

OffsetType offset;
ExtentType extent;
StrideType stride;

static_assert(is_integral_v<OffsetType> || __mdspan_is_integral_constant<OffsetType>::value);
static_assert(is_integral_v<ExtentType> || __mdspan_is_integral_constant<ExtentType>::value);
static_assert(is_integral_v<StrideType> || __mdspan_is_integral_constant<StrideType>::value);
};

} // experimental
Expand Down
24 changes: 12 additions & 12 deletions include/experimental/__p2630_bits/submdspan_extents.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
//
//@HEADER

#include "strided_index_range.hpp"
#include "strided_slice.hpp"
namespace std {
namespace experimental {
namespace detail {
Expand Down Expand Up @@ -42,12 +42,12 @@ constexpr auto inv_map_rank(integral_constant<size_t, Counter>, index_sequence<M
slices...);
}

// Helper for identifying strided_index_range
template <class T> struct is_strided_index_range : false_type {};
// Helper for identifying strided_slice
template <class T> struct is_strided_slice : false_type {};

template <class OffsetType, class ExtentType, class StrideType>
struct is_strided_index_range<
strided_index_range<OffsetType, ExtentType, StrideType>> : true_type {};
struct is_strided_slice<
strided_slice<OffsetType, ExtentType, StrideType>> : true_type {};

// first_of(slice): getting begin of slice specifier range
MDSPAN_TEMPLATE_REQUIRES(
Expand Down Expand Up @@ -77,7 +77,7 @@ constexpr auto first_of(const Slice &i) {
template <class OffsetType, class ExtentType, class StrideType>
MDSPAN_INLINE_FUNCTION
constexpr OffsetType
first_of(const strided_index_range<OffsetType, ExtentType, StrideType> &r) {
first_of(const strided_slice<OffsetType, ExtentType, StrideType> &r) {
return r.offset;
}

Expand Down Expand Up @@ -155,7 +155,7 @@ template <size_t k, class Extents, class OffsetType, class ExtentType,
MDSPAN_INLINE_FUNCTION
constexpr OffsetType
last_of(integral_constant<size_t, k>, const Extents &,
const strided_index_range<OffsetType, ExtentType, StrideType> &r) {
const strided_slice<OffsetType, ExtentType, StrideType> &r) {
return r.extent;
}

Expand All @@ -169,7 +169,7 @@ constexpr auto stride_of(const T &) {
template <class OffsetType, class ExtentType, class StrideType>
MDSPAN_INLINE_FUNCTION
constexpr auto
stride_of(const strided_index_range<OffsetType, ExtentType, StrideType> &r) {
stride_of(const strided_slice<OffsetType, ExtentType, StrideType> &r) {
return r.stride;
}

Expand All @@ -185,7 +185,7 @@ MDSPAN_INLINE_FUNCTION
constexpr auto divide(const integral_constant<T0, v0> &,
const integral_constant<T1, v1> &) {
// cutting short division by zero
// this is used for strided_index_range with zero extent/stride
// this is used for strided_slice with zero extent/stride
return integral_constant<IndexT, v0 == 0 ? 0 : v0 / v1>();
}

Expand Down Expand Up @@ -214,7 +214,7 @@ struct StaticExtentFromRange<std::integral_constant<Integral0, val0>,
constexpr static size_t value = val1 - val0;
};

// compute new static extent from strided_index_range, preserving static
// compute new static extent from strided_slice, preserving static
// knowledge
template <class Arg0, class Arg1> struct StaticExtentFromStridedRange {
constexpr static size_t value = dynamic_extent;
Expand All @@ -233,7 +233,7 @@ struct extents_constructor {
MDSPAN_TEMPLATE_REQUIRES(
class Slice, class... SlicesAndExtents,
/* requires */(!is_convertible_v<Slice, size_t> &&
!is_strided_index_range<Slice>::value)
!is_strided_slice<Slice>::value)
)
MDSPAN_INLINE_FUNCTION
constexpr static auto next_extent(const Extents &ext, const Slice &sl,
Expand Down Expand Up @@ -270,7 +270,7 @@ struct extents_constructor {
MDSPAN_INLINE_FUNCTION
constexpr static auto
next_extent(const Extents &ext,
const strided_index_range<OffsetType, ExtentType, StrideType> &r,
const strided_slice<OffsetType, ExtentType, StrideType> &r,
SlicesAndExtents... slices_and_extents) {
using index_t = typename Extents::index_type;
using new_static_extent_t =
Expand Down
38 changes: 19 additions & 19 deletions tests/test_submdspan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,22 +147,22 @@ using submdspan_test_types =
, std::tuple<stdex::layout_right, stdex::layout_right, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,dyn,8>, int, int, int, int, std::pair<int,int>, stdex::full_extent_t>
, std::tuple<stdex::layout_right, stdex::layout_right, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,8>, int, int, int, int, int, stdex::full_extent_t>
// LayoutLeft to LayoutStride
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>, int>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>, int>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::full_extent_t>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_index_range<int,int,int>, std::pair<int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_index_range<int,int,int>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_slice<int,int,int>, std::pair<int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_slice<int,int,int>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,6,dyn,8>, stdex::full_extent_t, int, std::pair<int,int>, int, int, stdex::full_extent_t>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,4,dyn,7>, int, stdex::full_extent_t, std::pair<int,int>, int, stdex::full_extent_t, int>
// layout_right to layout_stride
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::extents<size_t,0>, stdex::strided_index_range<int,std::integral_constant<int,0>,std::integral_constant<int,0>>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>, int>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::extents<size_t,0>, stdex::strided_slice<int,std::integral_constant<int,0>,std::integral_constant<int,0>>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>, int>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::full_extent_t, std::pair<int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_index_range<int,int,int>, std::pair<int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_index_range<int,int,int>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_slice<int,int,int>, std::pair<int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_slice<int,int,int>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,6,dyn,8>, stdex::full_extent_t, int, std::pair<int,int>, int, int, stdex::full_extent_t>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,4,dyn,7>, int, stdex::full_extent_t, std::pair<int,int>, int, stdex::full_extent_t, int>
// Testing of customization point design
Expand All @@ -172,8 +172,8 @@ using submdspan_test_types =
, std::tuple<Foo::layout_foo, Foo::layout_foo, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::full_extent_t, stdex::full_extent_t>
, std::tuple<Foo::layout_foo, Foo::layout_foo, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::full_extent_t>
, std::tuple<Foo::layout_foo, Foo::layout_foo, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, int, stdex::full_extent_t>
, std::tuple<Foo::layout_foo, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>>
, std::tuple<Foo::layout_foo, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>, int>
, std::tuple<Foo::layout_foo, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>>
, std::tuple<Foo::layout_foo, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>, int>
>;

template<class T> struct TestSubMDSpan;
Expand Down Expand Up @@ -204,13 +204,13 @@ struct TestSubMDSpan<
return std::pair<int,int>(1,3);
}
MDSPAN_INLINE_FUNCTION
static auto create_slice_arg(stdex::strided_index_range<int,int,int>) {
return stdex::strided_index_range<int,int,int>{1,3,2};
static auto create_slice_arg(stdex::strided_slice<int,int,int>) {
return stdex::strided_slice<int,int,int>{1,3,2};
}
template<int Ext, int Stride>
MDSPAN_INLINE_FUNCTION
static auto create_slice_arg(stdex::strided_index_range<int,std::integral_constant<int, Ext>, std::integral_constant<int, Stride>>) {
return stdex::strided_index_range<int,std::integral_constant<int, Ext>, std::integral_constant<int, Stride>>{1,std::integral_constant<int, Ext>(), std::integral_constant<int, Ext>()};
static auto create_slice_arg(stdex::strided_slice<int,std::integral_constant<int, Ext>, std::integral_constant<int, Stride>>) {
return stdex::strided_slice<int,std::integral_constant<int, Ext>, std::integral_constant<int, Stride>>{1,std::integral_constant<int, Ext>(), std::integral_constant<int, Ext>()};
}
MDSPAN_INLINE_FUNCTION
static auto create_slice_arg(stdex::full_extent_t) {
Expand All @@ -231,14 +231,14 @@ struct TestSubMDSpan<
template<class SrcExtents, class SubExtents, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext,
stdex::strided_index_range<int,int,int> p, SliceArgs ... slices) {
stdex::strided_slice<int,int,int> p, SliceArgs ... slices) {
using idx_t = typename SubExtents::index_type;
return (sub_ext.extent(sub_idx)==static_cast<idx_t>((p.extent+p.stride-1)/p.stride)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
}
template<class SrcExtents, class SubExtents, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext,
stdex::strided_index_range<int,std::integral_constant<int, 0>,std::integral_constant<int,0>>, SliceArgs ... slices) {
stdex::strided_slice<int,std::integral_constant<int, 0>,std::integral_constant<int,0>>, SliceArgs ... slices) {
return (sub_ext.extent(sub_idx)==0) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
}
template<class SrcExtents, class SubExtents, class ... SliceArgs>
Expand Down