Skip to content

Commit

Permalink
Merge pull request #231 from crtrott/update-strided-slice
Browse files Browse the repository at this point in the history
Update strided_index_range to strided_slice
  • Loading branch information
crtrott authored Jan 25, 2023
2 parents 5bbee18 + 5702ddc commit 9ce9099
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 33 deletions.
2 changes: 1 addition & 1 deletion include/experimental/__p0009_bits/layout_stride.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ struct layout_stride {

template<class SizeType, ::std::size_t ... Ep, ::std::size_t ... Idx>
_MDSPAN_HOST_DEVICE
constexpr index_type __get_size(extents<SizeType, Ep...>,integer_sequence<::std::size_t, Idx...>) const {
constexpr index_type __get_size(::std::experimental::extents<SizeType, Ep...>,integer_sequence<::std::size_t, Idx...>) const {
return _MDSPAN_FOLD_TIMES_RIGHT( static_cast<index_type>(extents().extent(Idx)), 1 );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,32 @@
//
//@HEADER

#include <type_traits>

namespace std {
namespace experimental {

namespace {
template<class T>
struct __mdspan_is_integral_constant: std::false_type {};

template<class T, T val>
struct __mdspan_is_integral_constant<integral_constant<T,val>>: std::true_type {};
}
// Slice Specifier allowing for strides and compile time extent
template <class OffsetType, class ExtentType, class StrideType>
struct strided_index_range {
struct strided_slice {
using offset_type = OffsetType;
using extent_type = ExtentType;
using stride_type = StrideType;

OffsetType offset;
ExtentType extent;
StrideType stride;

static_assert(is_integral_v<OffsetType> || __mdspan_is_integral_constant<OffsetType>::value);
static_assert(is_integral_v<ExtentType> || __mdspan_is_integral_constant<ExtentType>::value);
static_assert(is_integral_v<StrideType> || __mdspan_is_integral_constant<StrideType>::value);
};

} // experimental
Expand Down
24 changes: 12 additions & 12 deletions include/experimental/__p2630_bits/submdspan_extents.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
//
//@HEADER

#include "strided_index_range.hpp"
#include "strided_slice.hpp"
namespace std {
namespace experimental {
namespace detail {
Expand Down Expand Up @@ -42,12 +42,12 @@ constexpr auto inv_map_rank(integral_constant<size_t, Counter>, index_sequence<M
slices...);
}

// Helper for identifying strided_index_range
template <class T> struct is_strided_index_range : false_type {};
// Helper for identifying strided_slice
template <class T> struct is_strided_slice : false_type {};

template <class OffsetType, class ExtentType, class StrideType>
struct is_strided_index_range<
strided_index_range<OffsetType, ExtentType, StrideType>> : true_type {};
struct is_strided_slice<
strided_slice<OffsetType, ExtentType, StrideType>> : true_type {};

// first_of(slice): getting begin of slice specifier range
MDSPAN_TEMPLATE_REQUIRES(
Expand Down Expand Up @@ -77,7 +77,7 @@ constexpr auto first_of(const Slice &i) {
template <class OffsetType, class ExtentType, class StrideType>
MDSPAN_INLINE_FUNCTION
constexpr OffsetType
first_of(const strided_index_range<OffsetType, ExtentType, StrideType> &r) {
first_of(const strided_slice<OffsetType, ExtentType, StrideType> &r) {
return r.offset;
}

Expand Down Expand Up @@ -155,7 +155,7 @@ template <size_t k, class Extents, class OffsetType, class ExtentType,
MDSPAN_INLINE_FUNCTION
constexpr OffsetType
last_of(integral_constant<size_t, k>, const Extents &,
const strided_index_range<OffsetType, ExtentType, StrideType> &r) {
const strided_slice<OffsetType, ExtentType, StrideType> &r) {
return r.extent;
}

Expand All @@ -169,7 +169,7 @@ constexpr auto stride_of(const T &) {
template <class OffsetType, class ExtentType, class StrideType>
MDSPAN_INLINE_FUNCTION
constexpr auto
stride_of(const strided_index_range<OffsetType, ExtentType, StrideType> &r) {
stride_of(const strided_slice<OffsetType, ExtentType, StrideType> &r) {
return r.stride;
}

Expand All @@ -185,7 +185,7 @@ MDSPAN_INLINE_FUNCTION
constexpr auto divide(const integral_constant<T0, v0> &,
const integral_constant<T1, v1> &) {
// cutting short division by zero
// this is used for strided_index_range with zero extent/stride
// this is used for strided_slice with zero extent/stride
return integral_constant<IndexT, v0 == 0 ? 0 : v0 / v1>();
}

Expand Down Expand Up @@ -214,7 +214,7 @@ struct StaticExtentFromRange<std::integral_constant<Integral0, val0>,
constexpr static size_t value = val1 - val0;
};

// compute new static extent from strided_index_range, preserving static
// compute new static extent from strided_slice, preserving static
// knowledge
template <class Arg0, class Arg1> struct StaticExtentFromStridedRange {
constexpr static size_t value = dynamic_extent;
Expand All @@ -233,7 +233,7 @@ struct extents_constructor {
MDSPAN_TEMPLATE_REQUIRES(
class Slice, class... SlicesAndExtents,
/* requires */(!is_convertible_v<Slice, size_t> &&
!is_strided_index_range<Slice>::value)
!is_strided_slice<Slice>::value)
)
MDSPAN_INLINE_FUNCTION
constexpr static auto next_extent(const Extents &ext, const Slice &sl,
Expand Down Expand Up @@ -270,7 +270,7 @@ struct extents_constructor {
MDSPAN_INLINE_FUNCTION
constexpr static auto
next_extent(const Extents &ext,
const strided_index_range<OffsetType, ExtentType, StrideType> &r,
const strided_slice<OffsetType, ExtentType, StrideType> &r,
SlicesAndExtents... slices_and_extents) {
using index_t = typename Extents::index_type;
using new_static_extent_t =
Expand Down
38 changes: 19 additions & 19 deletions tests/test_submdspan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,22 +147,22 @@ using submdspan_test_types =
, std::tuple<stdex::layout_right, stdex::layout_right, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,dyn,8>, int, int, int, int, std::pair<int,int>, stdex::full_extent_t>
, std::tuple<stdex::layout_right, stdex::layout_right, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,8>, int, int, int, int, int, stdex::full_extent_t>
// LayoutLeft to LayoutStride
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>, int>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>, int>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::full_extent_t>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_index_range<int,int,int>, std::pair<int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_index_range<int,int,int>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_slice<int,int,int>, std::pair<int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_slice<int,int,int>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,6,dyn,8>, stdex::full_extent_t, int, std::pair<int,int>, int, int, stdex::full_extent_t>
, std::tuple<stdex::layout_left, stdex::layout_stride, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,4,dyn,7>, int, stdex::full_extent_t, std::pair<int,int>, int, stdex::full_extent_t, int>
// layout_right to layout_stride
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::extents<size_t,0>, stdex::strided_index_range<int,std::integral_constant<int,0>,std::integral_constant<int,0>>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>, int>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::extents<size_t,0>, stdex::strided_slice<int,std::integral_constant<int,0>,std::integral_constant<int,0>>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>, int>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::full_extent_t, std::pair<int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_index_range<int,int,int>, std::pair<int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_index_range<int,int,int>, stdex::strided_index_range<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_slice<int,int,int>, std::pair<int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::strided_slice<int,int,int>, stdex::strided_slice<int,int,int>>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,6,dyn,8>, stdex::full_extent_t, int, std::pair<int,int>, int, int, stdex::full_extent_t>
, std::tuple<stdex::layout_right, stdex::layout_stride, stdex::extents<size_t,6,4,5,6,7,8>, args_t<6,4,5,6,7,8>, stdex::extents<size_t,4,dyn,7>, int, stdex::full_extent_t, std::pair<int,int>, int, stdex::full_extent_t, int>
// Testing of customization point design
Expand All @@ -172,8 +172,8 @@ using submdspan_test_types =
, std::tuple<Foo::layout_foo, Foo::layout_foo, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, stdex::full_extent_t, stdex::full_extent_t>
, std::tuple<Foo::layout_foo, Foo::layout_foo, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,2>, std::pair<int,int>, stdex::full_extent_t>
, std::tuple<Foo::layout_foo, Foo::layout_foo, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, int, stdex::full_extent_t>
, std::tuple<Foo::layout_foo, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>>
, std::tuple<Foo::layout_foo, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_index_range<int,int,int>, int>
, std::tuple<Foo::layout_foo, stdex::layout_stride, stdex::dextents<size_t,1>, args_t<10>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>>
, std::tuple<Foo::layout_foo, stdex::layout_stride, stdex::dextents<size_t,2>, args_t<10,20>, stdex::dextents<size_t,1>, stdex::strided_slice<int,int,int>, int>
>;

template<class T> struct TestSubMDSpan;
Expand Down Expand Up @@ -204,13 +204,13 @@ struct TestSubMDSpan<
return std::pair<int,int>(1,3);
}
MDSPAN_INLINE_FUNCTION
static auto create_slice_arg(stdex::strided_index_range<int,int,int>) {
return stdex::strided_index_range<int,int,int>{1,3,2};
static auto create_slice_arg(stdex::strided_slice<int,int,int>) {
return stdex::strided_slice<int,int,int>{1,3,2};
}
template<int Ext, int Stride>
MDSPAN_INLINE_FUNCTION
static auto create_slice_arg(stdex::strided_index_range<int,std::integral_constant<int, Ext>, std::integral_constant<int, Stride>>) {
return stdex::strided_index_range<int,std::integral_constant<int, Ext>, std::integral_constant<int, Stride>>{1,std::integral_constant<int, Ext>(), std::integral_constant<int, Ext>()};
static auto create_slice_arg(stdex::strided_slice<int,std::integral_constant<int, Ext>, std::integral_constant<int, Stride>>) {
return stdex::strided_slice<int,std::integral_constant<int, Ext>, std::integral_constant<int, Stride>>{1,std::integral_constant<int, Ext>(), std::integral_constant<int, Ext>()};
}
MDSPAN_INLINE_FUNCTION
static auto create_slice_arg(stdex::full_extent_t) {
Expand All @@ -231,14 +231,14 @@ struct TestSubMDSpan<
template<class SrcExtents, class SubExtents, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext,
stdex::strided_index_range<int,int,int> p, SliceArgs ... slices) {
stdex::strided_slice<int,int,int> p, SliceArgs ... slices) {
using idx_t = typename SubExtents::index_type;
return (sub_ext.extent(sub_idx)==static_cast<idx_t>((p.extent+p.stride-1)/p.stride)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
}
template<class SrcExtents, class SubExtents, class ... SliceArgs>
MDSPAN_INLINE_FUNCTION
static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext,
stdex::strided_index_range<int,std::integral_constant<int, 0>,std::integral_constant<int,0>>, SliceArgs ... slices) {
stdex::strided_slice<int,std::integral_constant<int, 0>,std::integral_constant<int,0>>, SliceArgs ... slices) {
return (sub_ext.extent(sub_idx)==0) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...);
}
template<class SrcExtents, class SubExtents, class ... SliceArgs>
Expand Down

0 comments on commit 9ce9099

Please sign in to comment.