Skip to content

Commit

Permalink
Fix Utils::Mpi::gather_buffer() functions (#4075)
Browse files Browse the repository at this point in the history
Fixes #4074

Description of changes:
- fix a bug that caused buffers to overwrite each other in `Utils::Mpi::gather_buffer()` when argument `root` was not 0
- fix two broken unit tests for `Utils::Mpi` functions
  • Loading branch information
kodiakhq[bot] authored Jan 14, 2021
2 parents 979acef + fc34689 commit 46372ec
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 15 deletions.
43 changes: 33 additions & 10 deletions src/utils/include/utils/mpi/gather_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,33 @@

namespace Utils {
namespace Mpi {
namespace detail {
/**
 * @brief Shift the local data of the root rank to its final position.
 *
 * Copies the first <tt>sizes[root]</tt> elements of @p buffer to the range
 * starting at index <tt>displ[root]</tt>. Nothing is done when the root
 * holds no elements or when its displacement is zero.
 *
 * @param buffer Buffer holding the local data in the beginning.
 *               NOTE(review): assumed to be large enough to hold
 *               <tt>displ[root] + sizes[root]</tt> elements; this is
 *               not checked here.
 * @param sizes  Number of elements per rank.
 * @param displ  Offset of each rank's data in @p buffer.
 * @param root   Index used to look up the entries in @p sizes and @p displ.
 */
template <typename T>
void relocate_data(T *buffer, std::vector<int> const &sizes,
                   std::vector<int> const &displ, int root) {
  /* Either there is no data to move, or it is already in place. */
  if (sizes[root] == 0 || displ[root] == 0) {
    return;
  }

  auto const offset = displ[root];

  /* Copy from the last element to the first one, so that elements of an
   * overlapping source range are read before they get overwritten
   * (for a positive offset). */
  auto index = sizes[root];
  while (index > 0) {
    --index;
    buffer[index + offset] = buffer[index];
  }
}
} // namespace detail

/**
* @brief Gather buffer with different size on each node.
*
* Gathers buffers with different lengths from all nodes to root.
* The buffer is assumed to be large enough to hold the data from
* all the nodes and is owned by the caller. On the root node no
* data is copied, and the first n_elem elements of buffer are not
* touched. This combines a common combination of MPI_Gather and
* MPI_{Send,Recv}.
* all the nodes and is owned by the caller. On the @p root node,
* the first @p n_elem elements of @p buffer are moved, if need
* be. On the other nodes, @p buffer is not touched.
*
* This encapsulates a common combination of <tt>MPI_Gather()</tt>
* and <tt>MPI_{Send,Recv}()</tt>.
*
* @param buffer On the master the target buffer that has to be
large enough to hold all elements and has the local
part in the beginning. On the slaves the local buffer.
* large enough to hold all elements and has the local
* part in the beginning. On the slaves the local buffer.
* @param n_elem The number of elements in the local buffer.
* @param comm The MPI communicator.
* @param root The rank where the data should be gathered.
Expand All @@ -63,11 +76,15 @@ int gather_buffer(T *buffer, int n_elem, boost::mpi::communicator comm,
auto const total_size =
detail::size_and_offset<T>(sizes, displ, n_elem, comm, root);

/* Move the original data to its new location */
detail::relocate_data(buffer, sizes, displ, root);

/* Gather data */
gatherv(comm, buffer, 0, buffer, sizes.data(), displ.data(), root);

return total_size;
}
/* Send local size */
detail::size_and_offset(n_elem, comm, root);
/* Send data */
gatherv(comm, buffer, n_elem, static_cast<T *>(nullptr), nullptr, nullptr,
Expand All @@ -80,12 +97,15 @@ int gather_buffer(T *buffer, int n_elem, boost::mpi::communicator comm,
* @brief Gather buffer with different size on each node.
*
* Gathers buffers with different lengths from all nodes to root.
* The buffer is resized to the total size. On the root node no
* data is copied, and the first n_elem elements of buffer are not
* touched. On the slaves, the buffer is not touched.
* The buffer is resized to the total size. On the @p root node,
* the first @p n_elem elements of @p buffer are moved, if need
* be. On the other nodes, @p buffer is not touched.
*
* This encapsulates a common combination of <tt>MPI_Gather()</tt>
* and <tt>MPI_{Send,Recv}()</tt>.
*
* @param buffer On the master the target buffer that has the local
part in the beginning. On the slaves the local buffer.
* part in the beginning. On the slaves the local buffer.
* @param comm The MPI communicator.
* @param root The rank where the data should be gathered.
*/
Expand All @@ -104,6 +124,9 @@ void gather_buffer(std::vector<T, Allocator> &buffer,
/* Resize the buffer */
buffer.resize(tot_size);

/* Move the original data to its new location */
detail::relocate_data(buffer.data(), sizes, displ, root);

/* Gather data */
gatherv(comm, buffer.data(), buffer.size(), buffer.data(), sizes.data(),
displ.data(), root);
Expand Down
2 changes: 1 addition & 1 deletion src/utils/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ unit_test(NAME scatter_buffer_test SRC scatter_buffer_test.cpp DEPENDS
unit_test(NAME all_compare_test SRC all_compare_test.cpp DEPENDS EspressoUtils
Boost::mpi MPI::MPI_CXX NUM_PROC 3)
unit_test(NAME gatherv_test SRC gatherv_test.cpp DEPENDS EspressoUtils
Boost::mpi MPI::MPI_CXX)
Boost::mpi MPI::MPI_CXX NUM_PROC 3)
unit_test(NAME all_gatherv_test SRC all_gatherv_test.cpp DEPENDS EspressoUtils
Boost::mpi MPI::MPI_CXX)
unit_test(NAME sendrecv_test SRC sendrecv_test.cpp DEPENDS EspressoUtils
Expand Down
75 changes: 72 additions & 3 deletions src/utils/tests/gather_buffer_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,51 @@ void check_vector(const mpi::communicator &comm, int root) {
}
}

/**
 * Gather vectors of different lengths on rank 1 and verify that moving
 * the root's own data inside the buffer does not access elements out
 * of bounds (relies on the assertions of std::vector).
 *
 * Rank 0 contributes one element, rank 1 (the root) two elements and
 * every other rank an empty vector.
 */
void check_vector_out_of_bounds(const mpi::communicator &comm) {
  const auto root = 1;
  const auto rank = comm.rank();

  if (rank == 0) {
    /* Sender with one element: its buffer must stay as it was. */
    std::vector<int> buf = {1};
    gather_buffer(buf, comm, root);
    BOOST_CHECK(buf.size() == 1);
    BOOST_CHECK(buf[0] == 1);
    return;
  }

  if (rank == 1) {
    /* Root: its two elements end up behind the element of rank 0. */
    std::vector<int> buf = {2, 2};
    gather_buffer(buf, comm, root);
    BOOST_CHECK(buf.size() == 3);
    BOOST_CHECK(buf[0] == 1);
    BOOST_CHECK(buf[1] == 2);
    BOOST_CHECK(buf[2] == 2);
    return;
  }

  /* Remaining ranks send nothing and must keep an empty buffer. */
  std::vector<int> buf = {};
  gather_buffer(buf, comm, root);
  BOOST_CHECK(buf.empty());
}

/**
 * Gather raw buffers of different lengths on rank 1 and verify that
 * moving the root's own data inside the buffer does not write past the
 * gathered range. The root's buffer carries a trailing sentinel value
 * (-1) that has to survive the call.
 *
 * Rank 0 contributes one element, rank 1 (the root) two elements and
 * every other rank no element.
 */
void check_pointer_out_of_bounds(const mpi::communicator &comm) {
  const auto root = 1;
  const auto rank = comm.rank();

  if (rank == 0) {
    /* Sender with one element: its buffer must stay as it was. */
    std::vector<int> buf = {1};
    gather_buffer(buf.data(), 1, comm, root);
    BOOST_CHECK(buf[0] == 1);
    return;
  }

  if (rank == 1) {
    /* Root: two local elements, one free slot, then the sentinel. */
    std::vector<int> buf = {2, 2, 0, -1};
    gather_buffer(buf.data(), 2, comm, root);
    BOOST_CHECK(buf.size() == 4);
    BOOST_CHECK(buf[0] == 1);
    BOOST_CHECK(buf[1] == 2);
    BOOST_CHECK(buf[2] == 2);
    BOOST_CHECK(buf[3] == -1);
    return;
  }

  /* Remaining ranks take part in the gather without any data. */
  std::vector<int> buf = {};
  gather_buffer(buf.data(), 0, comm, root);
}

void check_vector_empty(const mpi::communicator &comm, int empty) {
std::vector<int> buf((comm.rank() == empty) ? 0 : 11, comm.rank());
gather_buffer(buf, comm);
Expand Down Expand Up @@ -149,6 +194,18 @@ BOOST_AUTO_TEST_CASE(pointer) {
check_pointer(world, 0);
}

/* Run the pointer check with 1 as second argument (presumably the root
 * rank -- confirm against check_pointer()); skipped with fewer than two
 * ranks. */
BOOST_AUTO_TEST_CASE(pointer_overlap) {
  mpi::communicator world;
  if (world.size() < 2) {
    return;
  }
  check_pointer(world, 1);
}

/* Run the out-of-bounds check for the pointer interface; skipped with
 * fewer than two ranks. */
BOOST_AUTO_TEST_CASE(pointer_out_of_bounds) {
  mpi::communicator world;
  if (world.size() < 2) {
    return;
  }
  check_pointer_out_of_bounds(world);
}

BOOST_AUTO_TEST_CASE(pointer_root) {
mpi::communicator world;

Expand All @@ -158,14 +215,26 @@ BOOST_AUTO_TEST_CASE(pointer_root) {

BOOST_AUTO_TEST_CASE(vector) {
mpi::communicator world;
check_pointer(world, 0);
check_vector(world, 0);
}

/* Run the vector check with rank 1 as root; skipped with fewer than two
 * ranks. */
BOOST_AUTO_TEST_CASE(vector_overlap) {
  mpi::communicator world;
  if (world.size() < 2) {
    return;
  }
  check_vector(world, 1);
}

/* Run the out-of-bounds check for the vector interface; skipped with
 * fewer than two ranks. */
BOOST_AUTO_TEST_CASE(vector_out_of_bounds) {
  mpi::communicator world;
  if (world.size() < 2) {
    return;
  }
  check_vector_out_of_bounds(world);
}

BOOST_AUTO_TEST_CASE(vector_root) {
mpi::communicator world;

auto empty = (world.size() >= 3) ? world.size() - 2 : world.size() - 1;
check_pointer(world, empty);
auto root = (world.size() >= 3) ? world.size() - 2 : world.size() - 1;
check_vector(world, root);
}

BOOST_AUTO_TEST_CASE(vector_empty) {
Expand Down
2 changes: 1 addition & 1 deletion src/utils/tests/gatherv_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

#define BOOST_TEST_NO_MAIN
#define BOOST_TEST_MODULE Utils::Mpi::gatherv_test test
#define BOOST_TEST_MODULE Utils::Mpi::gatherv test
#define BOOST_TEST_DYN_LINK
#include <boost/test/unit_test.hpp>

Expand Down

0 comments on commit 46372ec

Please sign in to comment.