Skip to content

Commit

Permalink
Fix empty loops (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
tpadioleau authored Nov 11, 2022
1 parent e0c1806 commit 6d39239
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
44 changes: 24 additions & 20 deletions include/ddc/for_each.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ class ForEachKokkosLambdaAdapter
template <class ExecSpace, class Functor, class DDim0>
inline void for_each_kokkos(DiscreteDomain<DDim0> const& domain, Functor const& f) noexcept
{
DiscreteElement<DDim0> const ddc_begin = domain.front();
DiscreteElement<DDim0> const ddc_end = domain.front() + domain.extents();
std::size_t const begin = ddc::uid<DDim0>(ddc_begin);
std::size_t const end = ddc::uid<DDim0>(ddc_end);
Kokkos::parallel_for(
Kokkos::RangePolicy<ExecSpace>(
select<DDim0>(domain).front().uid(),
select<DDim0>(domain).back().uid() + 1),
Kokkos::RangePolicy<ExecSpace>(begin, end),
ForEachKokkosLambdaAdapter<Functor, DDim0>(f));
}

Expand All @@ -49,14 +51,14 @@ inline void for_each_kokkos(
DiscreteDomain<DDim0, DDim1, DDims...> const& domain,
Functor&& f) noexcept
{
DiscreteElement<DDim0, DDim1, DDims...> const ddc_begin = domain.front();
DiscreteElement<DDim0, DDim1, DDims...> const ddc_end = domain.front() + domain.extents();
Kokkos::Array<std::size_t, 2 + sizeof...(DDims)> const
begin {select<DDim0>(domain).front().uid(),
select<DDim1>(domain).front().uid(),
select<DDims>(domain).front().uid()...};
begin {ddc::uid<DDim0>(ddc_begin),
ddc::uid<DDim1>(ddc_begin),
ddc::uid<DDims>(ddc_begin)...};
Kokkos::Array<std::size_t, 2 + sizeof...(DDims)> const
end {(select<DDim0>(domain).back().uid() + 1),
(select<DDim1>(domain).back().uid() + 1),
(select<DDims>(domain).back().uid() + 1)...};
end {ddc::uid<DDim0>(ddc_end), ddc::uid<DDim1>(ddc_end), ddc::uid<DDims>(ddc_end)...};
Kokkos::parallel_for(
Kokkos::MDRangePolicy<
ExecSpace,
Expand All @@ -69,7 +71,7 @@ inline void for_each_kokkos(

template <class RetType, class Element, std::size_t N, class Functor, class... Is>
inline void for_each_serial(
std::array<Element, N> const& start,
std::array<Element, N> const& begin,
std::array<Element, N> const& end,
Functor const& f,
Is const&... is) noexcept
Expand All @@ -78,8 +80,8 @@ inline void for_each_serial(
if constexpr (I == N) {
f(RetType(is...));
} else {
for (Element ii = start[I]; ii <= end[I]; ++ii) {
for_each_serial<RetType>(start, end, f, is..., ii);
for (Element ii = begin[I]; ii < end[I]; ++ii) {
for_each_serial<RetType>(begin, end, f, is..., ii);
}
}
}
Expand All @@ -101,10 +103,11 @@ inline void for_each(
DiscreteDomain<DDims...> const& domain,
Functor&& f) noexcept
{
ddc_detail::for_each_serial<DiscreteElement<DDims...>>(
ddc_detail::array(domain.front()),
ddc_detail::array(domain.back()),
std::forward<Functor>(f));
DiscreteElement<DDims...> const ddc_begin = domain.front();
DiscreteElement<DDims...> const ddc_end = domain.front() + domain.extents();
std::array const begin = ddc_detail::array(ddc_begin);
std::array const end = ddc_detail::array(ddc_end);
ddc_detail::for_each_serial<DiscreteElement<DDims...>>(begin, end, std::forward<Functor>(f));
}

/** iterates over a nD extent using the serial execution policy
Expand All @@ -117,10 +120,11 @@ inline void for_each_n(
DiscreteVector<DDims...> const& extent,
Functor&& f) noexcept
{
ddc_detail::for_each_serial<DiscreteVector<DDims...>>(
std::array<DiscreteVectorElement, sizeof...(DDims)> {},
std::array<DiscreteVectorElement, sizeof...(DDims)> {get<DDims>(extent) - 1 ...},
std::forward<Functor>(f));
DiscreteVector<DDims...> const ddc_begin {};
DiscreteVector<DDims...> const ddc_end = extent;
std::array const begin = ddc_detail::array(ddc_begin);
std::array const end = ddc_detail::array(ddc_end);
ddc_detail::for_each_serial<DiscreteVector<DDims...>>(begin, end, std::forward<Functor>(f));
}

/// Parallel execution on the default device
Expand Down
10 changes: 10 additions & 0 deletions tests/for_each.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ static DVectY constexpr nelems_y(12);
static DElemXY constexpr lbound_x_y {lbound_x, lbound_y};
static DVectXY constexpr nelems_x_y(nelems_x, nelems_y);

TEST(ForEachSerialHost, Empty)
{
DDomX const dom(lbound_x, DVectX(0));
std::vector<int> storage(dom.size(), 0);
ddc::ChunkSpan<int, DDomX> view(storage.data(), dom);
ddc::for_each(ddc::policies::serial_host, dom, [=](DElemX const ix) { view(ix) += 1; });
ASSERT_EQ(std::count(storage.begin(), storage.end(), 1), dom.size());
std::cout << std::count(storage.begin(), storage.end(), 1) << std::endl;
}

TEST(ForEachSerialHost, OneDimension)
{
DDomX const dom(lbound_x, nelems_x);
Expand Down

0 comments on commit 6d39239

Please sign in to comment.