From 6d39239d577fc4d58e6b6ecda4229222ce9044d0 Mon Sep 17 00:00:00 2001 From: Thomas Padioleau Date: Fri, 11 Nov 2022 14:55:44 +0100 Subject: [PATCH] Fix empty loops (#89) --- include/ddc/for_each.hpp | 44 ++++++++++++++++++++++------------------ tests/for_each.cpp | 10 +++++++++ 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/include/ddc/for_each.hpp b/include/ddc/for_each.hpp index f70868307..1521d4351 100644 --- a/include/ddc/for_each.hpp +++ b/include/ddc/for_each.hpp @@ -37,10 +37,12 @@ class ForEachKokkosLambdaAdapter template inline void for_each_kokkos(DiscreteDomain const& domain, Functor const& f) noexcept { + DiscreteElement const ddc_begin = domain.front(); + DiscreteElement const ddc_end = domain.front() + domain.extents(); + std::size_t const begin = ddc::uid(ddc_begin); + std::size_t const end = ddc::uid(ddc_end); Kokkos::parallel_for( - Kokkos::RangePolicy( - select(domain).front().uid(), - select(domain).back().uid() + 1), + Kokkos::RangePolicy(begin, end), ForEachKokkosLambdaAdapter(f)); } @@ -49,14 +51,14 @@ inline void for_each_kokkos( DiscreteDomain const& domain, Functor&& f) noexcept { + DiscreteElement const ddc_begin = domain.front(); + DiscreteElement const ddc_end = domain.front() + domain.extents(); Kokkos::Array const - begin {select(domain).front().uid(), - select(domain).front().uid(), - select(domain).front().uid()...}; + begin {ddc::uid(ddc_begin), + ddc::uid(ddc_begin), + ddc::uid(ddc_begin)...}; Kokkos::Array const - end {(select(domain).back().uid() + 1), - (select(domain).back().uid() + 1), - (select(domain).back().uid() + 1)...}; + end {ddc::uid(ddc_end), ddc::uid(ddc_end), ddc::uid(ddc_end)...}; Kokkos::parallel_for( Kokkos::MDRangePolicy< ExecSpace, @@ -69,7 +71,7 @@ inline void for_each_kokkos( template inline void for_each_serial( - std::array const& start, + std::array const& begin, std::array const& end, Functor const& f, Is const&... is) noexcept @@ -78,8 +80,8 @@ inline void for_each_serial( if constexpr (I == N) { f(RetType(is...)); } else { - for (Element ii = start[I]; ii <= end[I]; ++ii) { - for_each_serial(start, end, f, is..., ii); + for (Element ii = begin[I]; ii < end[I]; ++ii) { + for_each_serial(begin, end, f, is..., ii); } } } @@ -101,10 +103,11 @@ inline void for_each( DiscreteDomain const& domain, Functor&& f) noexcept { - ddc_detail::for_each_serial>( - ddc_detail::array(domain.front()), - ddc_detail::array(domain.back()), - std::forward(f)); + DiscreteElement const ddc_begin = domain.front(); + DiscreteElement const ddc_end = domain.front() + domain.extents(); + std::array const begin = ddc_detail::array(ddc_begin); + std::array const end = ddc_detail::array(ddc_end); + ddc_detail::for_each_serial>(begin, end, std::forward(f)); } /** iterates over a nD extent using the serial execution policy @@ -117,10 +120,11 @@ inline void for_each_n( DiscreteVector const& extent, Functor&& f) noexcept { - ddc_detail::for_each_serial>( - std::array {}, - std::array {get(extent) - 1 ...}, - std::forward(f)); + DiscreteVector const ddc_begin {}; + DiscreteVector const ddc_end = extent; + std::array const begin = ddc_detail::array(ddc_begin); + std::array const end = ddc_detail::array(ddc_end); + ddc_detail::for_each_serial>(begin, end, std::forward(f)); } /// Parallel execution on the default device diff --git a/tests/for_each.cpp b/tests/for_each.cpp index bb2ee44b2..e07a49e6c 100644 --- a/tests/for_each.cpp +++ b/tests/for_each.cpp @@ -27,6 +27,16 @@ static DVectY constexpr nelems_y(12); static DElemXY constexpr lbound_x_y {lbound_x, lbound_y}; static DVectXY constexpr nelems_x_y(nelems_x, nelems_y); +TEST(ForEachSerialHost, Empty) +{ + DDomX const dom(lbound_x, DVectX(0)); + std::vector storage(dom.size(), 0); + ddc::ChunkSpan view(storage.data(), dom); + ddc::for_each(ddc::policies::serial_host, dom, [=](DElemX const ix) { view(ix) += 1; }); + ASSERT_EQ(std::count(storage.begin(), storage.end(), 1), dom.size()); + std::cout << std::count(storage.begin(), storage.end(), 1) << std::endl; +} + TEST(ForEachSerialHost, OneDimension) { DDomX const dom(lbound_x, nelems_x);