Skip to content

Commit

Permalink
Add Const to SID Neighbor Table Element Type (#1808)
Browse files Browse the repository at this point in the history
Allows NVCC to use `LDG` instructions and thus perform better
double-load elision. Credit goes to @iomaganaris. Further, fixes
`sid::as_const` for C arrays.

---------

Co-authored-by: Ioannis Magkanaris <ioannis.magkanaris@cscs.ch>
  • Loading branch information
fthaler and iomaganaris authored Oct 30, 2024
1 parent 049d6fe commit 85de9ff
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 7 deletions.
10 changes: 6 additions & 4 deletions include/gridtools/fn/sid_neighbor_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "../common/array.hpp"
#include "../common/const_ptr_deref.hpp"
#include "../fn/unstructured.hpp"
#include "../sid/as_const.hpp"
#include "../sid/concept.hpp"

namespace gridtools::fn::sid_neighbor_table {
Expand Down Expand Up @@ -61,14 +62,15 @@ namespace gridtools::fn::sid_neighbor_table {
static_assert(!std::is_same_v<IndexDimension, NeighborDimension>,
"The index dimension and the neighbor dimension must be different.");

const auto origin = sid::get_origin(sid);
const auto strides = sid::get_strides(sid);
decltype(auto) const_sid = sid::as_const(std::forward<Sid>(sid));
const auto origin = sid::get_origin(const_sid);
const auto strides = sid::get_strides(const_sid);

return sid_neighbor_table<IndexDimension,
NeighborDimension,
MaxNumNeighbors,
sid::ptr_holder_type<Sid>,
sid::strides_type<Sid>>{
decltype(origin),
decltype(strides)>{
origin, strides}; // Note: putting the return type into the function signature will crash nvcc 12.0
}
} // namespace sid_neighbor_table_impl_
Expand Down
4 changes: 2 additions & 2 deletions include/gridtools/sid/as_const.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ namespace gridtools {
* probably might we need the `host` and `device` variations as well
*/
template <class Src,
class Ptr = sid::ptr_type<std::decay_t<Src>>,
class Ptr = sid::ptr_type<std::remove_cv_t<std::remove_reference_t<Src>>>,
std::enable_if_t<std::is_pointer_v<Ptr> && !std::is_const_v<std::remove_pointer_t<Ptr>>, int> = 0>
as_const_impl_::const_adapter<Src> as_const(Src &&src) {
return {std::forward<Src>(src)};
}

template <class Src,
class Ptr = sid::ptr_type<std::decay_t<Src>>,
class Ptr = sid::ptr_type<std::remove_cv_t<std::remove_reference_t<Src>>>,
std::enable_if_t<!std::is_pointer_v<Ptr> || std::is_const_v<std::remove_pointer_t<Ptr>>, int> = 0>
decltype(auto) as_const(Src &&src) {
return std::forward<Src>(src);
Expand Down
5 changes: 4 additions & 1 deletion tests/unit_tests/fn/test_fn_sid_neighbor_table.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ namespace gridtools::fn {
using dim_hymap_t = hymap::keys<edge_dim_t, edge_to_cell_dim_t>;
auto contents = sid::synthetic()
.set<sid::property::origin>(sid::host_device::simple_ptr_holder(device_data.get()))
.set<sid::property::strides>(dim_hymap_t::make_values(num_neighbors, 1));
.set<sid::property::strides>(dim_hymap_t::make_values(num_neighbors, 1))
// for whatever reason, setting strides_kind is required
// by Clang-CUDA (tested Clang 17 + CUDA 12.4)
.set<sid::property::strides_kind, sid::unknown_kind>();

const auto table = as_neighbor_table<edge_dim_t, edge_to_cell_dim_t, num_neighbors>(contents);
using table_t = std::decay_t<decltype(table)>;
Expand Down
10 changes: 10 additions & 0 deletions tests/unit_tests/sid/test_sid_as_const.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,15 @@ namespace gridtools {
static_assert(std::is_same_v<sid::ptr_type<testee_t>, double const *>);
EXPECT_EQ(sid::get_origin(src)(), sid::get_origin(testee)());
}

TEST(as_const, c_array) {
int src[3][2] = {{0, 1}, {10, 11}, {20, 21}};
auto testee = sid::as_const(src);
using testee_t = decltype(testee);

static_assert(is_sid<testee_t>());
static_assert(std::is_same_v<sid::ptr_type<testee_t>, int const *>);
EXPECT_EQ(sid::get_origin(src)(), sid::get_origin(testee)());
}
} // namespace
} // namespace gridtools

0 comments on commit 85de9ff

Please sign in to comment.