Skip to content

Commit

Permalink
Optimize multishift for gcc9 (#1630)
Browse files Browse the repository at this point in the history
  • Loading branch information
anstaf authored Apr 7, 2021
1 parent 4978ca8 commit 2072fd2
Show file tree
Hide file tree
Showing 7 changed files with 45,354 additions and 42,751 deletions.
53 changes: 44 additions & 9 deletions include/gridtools/sid/multi_shift.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,61 @@
#include <type_traits>

#include "../common/defs.hpp"
#include "../common/for_each.hpp"
#include "../common/host_device.hpp"
#include "../common/hymap.hpp"
#include "concept.hpp"

namespace gridtools {
namespace sid {
namespace multi_shift_impl_ {
/**
 * Calls `shift` on `ptr` once for every key in a type list.
 *
 * All the work is done in the constructor: creating a temporary
 * `for_each_dim<Keys>(ptr, strides, offsets)` performs the shifts.
 * For each key `K` the stride comes from `get_stride<K>(strides)` and the
 * offset from `host_device::at_key<K>(offsets)`.
 */
template <class Keys>
struct for_each_dim;

// Empty key list: there is nothing to shift. This case is handled separately
// because the general specialization would otherwise form a zero-length array.
template <template <class...> class List>
struct for_each_dim<List<>> {
    template <class Ptr, class Strides, class Offsets>
    GT_FUNCTION for_each_dim(Ptr &, Strides const &, Offsets) {}
};

// Non-empty key list: expand the pack into the initializer of a throw-away int array.
template <template <class...> class List, class... Keys>
struct for_each_dim<List<Keys...>> {
    template <class Ptr, class Strides, class Offsets>
    GT_FUNCTION for_each_dim(Ptr &ptr, Strides const &strides, Offsets offsets) {
        // Each initializer is `(shift(...), 0)`: the shift is evaluated for its side effect
        // and the comma operator supplies the int stored in the array. The array itself
        // is discarded.
        using expansion_t = int[sizeof...(Keys)];
        (void)expansion_t{
            (shift(ptr, get_stride<Keys>(strides), gridtools::host_device::at_key<Keys>(offsets)), 0)...};
    }
};

/**
 * Variant of the per-key shift helper that selects the strides of a single argument.
 *
 * Creating a temporary `for_each_dim_a<Arg, Keys>(ptr, strides, offsets)` calls `shift`
 * on `ptr` once for every key `K` in the list, with the stride taken from
 * `get_stride_element<Arg, K>(strides)` and the offset from
 * `host_device::at_key<K>(offsets)`.
 */
template <class Arg, class Keys>
struct for_each_dim_a;

// Empty key list: there is nothing to shift. This case is handled separately
// because the general specialization would otherwise form a zero-length array.
template <class Arg, template <class...> class List>
struct for_each_dim_a<Arg, List<>> {
    template <class Ptr, class Strides, class Offsets>
    GT_FUNCTION for_each_dim_a(Ptr &, Strides const &, Offsets) {}
};

// Non-empty key list: expand the pack into the initializer of a throw-away int array.
template <class Arg, template <class...> class List, class... Keys>
struct for_each_dim_a<Arg, List<Keys...>> {
    template <class Ptr, class Strides, class Offsets>
    GT_FUNCTION for_each_dim_a(Ptr &ptr, Strides const &strides, Offsets offsets) {
        // Each initializer is `(shift(...), 0)`: the shift is evaluated for its side effect
        // and the comma operator supplies the int stored in the array. The array itself
        // is discarded.
        using expansion_t = int[sizeof...(Keys)];
        (void)expansion_t{(shift(ptr,
                               get_stride_element<Arg, Keys>(strides),
                               gridtools::host_device::at_key<Keys>(offsets)),
            0)...};
    }
};
} // namespace multi_shift_impl_

/**
 * A helper that invokes `sid::shift` in several dimensions.
 * `offsets` should be a hymap of individual offsets.
 *
 * For every key of `offsets`, `shift` is called exactly once on `ptr` with the stride
 * looked up in `strides` under that key and the offset stored under the same key.
 */
template <class Ptr, class Strides, class Offsets>
GT_FUNCTION void multi_shift(Ptr &ptr, Strides const &strides, Offsets offsets) {
    // The pack-expansion helper does all the shifting in its constructor. It must be the
    // only traversal of the keys: running a `host_device::for_each` loop over the same
    // keys as well would apply every offset to `ptr` twice.
    multi_shift_impl_::for_each_dim<get_keys<Offsets>>(ptr, strides, wstd::move(offsets));
}

template <class Ptr, class Strides, class Offsets>
Expand All @@ -42,10 +80,7 @@ namespace gridtools {
*/
template <class Arg, class Ptr, class Strides, class Offsets>
GT_FUNCTION void multi_shift(Ptr &ptr, Strides const &strides, Offsets offsets) {
    // For every key of `offsets`, `shift` is called exactly once on `ptr` with
    // `get_stride_element<Arg, key>(strides)` and the offset stored under that key.
    // The pack-expansion helper does this in its constructor and must be the only
    // traversal of the keys: running a `host_device::for_each` loop over the same keys
    // as well would apply every offset to `ptr` twice.
    multi_shift_impl_::for_each_dim_a<Arg, get_keys<Offsets>>(ptr, strides, wstd::move(offsets));
}

template <class Arg, class Ptr, class Strides, class Offsets>
Expand Down
Loading

0 comments on commit 2072fd2

Please sign in to comment.