Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove MPI static globals #4858

Merged
merged 2 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/push_pull.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ permissions:
jobs:
macos:
runs-on: macos-12
if: false
if: ${{ github.repository == 'espressomd/espresso' }}
steps:
- name: Checkout
uses: actions/checkout@main
Expand Down
3 changes: 0 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,6 @@ if(ESPRESSO_BUILD_TESTS)
endif()

find_package(Boost 1.74.0 REQUIRED ${BOOST_COMPONENTS})
if(${Boost_VERSION} VERSION_GREATER_EQUAL 1.84.0)
message(FATAL_ERROR "Boost version ${Boost_VERSION} is unsupported.")
endif()

#
# Paths
Expand Down
31 changes: 17 additions & 14 deletions src/core/MpiCallbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

#include <boost/mpi/collectives/broadcast.hpp>
#include <boost/mpi/communicator.hpp>
#include <boost/mpi/environment.hpp>
#include <boost/mpi/packed_iarchive.hpp>
#include <boost/range/algorithm/remove_if.hpp>

Expand Down Expand Up @@ -201,8 +202,8 @@ class MpiCallbacks {
template <typename F, class = std::enable_if_t<std::is_same_v<
typename detail::functor_types<F>::argument_types,
std::tuple<Args...>>>>
CallbackHandle(MpiCallbacks *cb, F &&f)
: m_id(cb->add(std::forward<F>(f))), m_cb(cb) {}
CallbackHandle(std::shared_ptr<MpiCallbacks> cb, F &&f)
: m_id(cb->add(std::forward<F>(f))), m_cb(std::move(cb)) {}

CallbackHandle(CallbackHandle const &) = delete;
CallbackHandle(CallbackHandle &&rhs) noexcept = default;
Expand All @@ -211,7 +212,7 @@ class MpiCallbacks {

private:
int m_id;
MpiCallbacks *m_cb;
std::shared_ptr<MpiCallbacks> m_cb;

public:
/**
Expand All @@ -237,7 +238,6 @@ class MpiCallbacks {
m_cb->remove(m_id);
}

MpiCallbacks *cb() const { return m_cb; }
int id() const { return m_id; }
};

Expand All @@ -255,9 +255,9 @@ class MpiCallbacks {
}

public:
explicit MpiCallbacks(boost::mpi::communicator comm,
bool abort_on_exit = true)
: m_abort_on_exit(abort_on_exit), m_comm(std::move(comm)) {
MpiCallbacks(boost::mpi::communicator comm,
std::shared_ptr<boost::mpi::environment> mpi_env)
: m_comm(std::move(comm)), m_mpi_env(std::move(mpi_env)) {
/* Add a dummy at id 0 for loop abort. */
m_callback_map.add(nullptr);

Expand All @@ -268,7 +268,7 @@ class MpiCallbacks {

~MpiCallbacks() {
/* Release the clients on exit */
if (m_abort_on_exit && (m_comm.rank() == 0)) {
if (m_comm.rank() == 0) {
try {
abort_loop();
} catch (...) {
Expand Down Expand Up @@ -447,22 +447,25 @@ class MpiCallbacks {
*/
boost::mpi::communicator const &comm() const { return m_comm; }

std::shared_ptr<boost::mpi::environment> share_mpi_env() const {
return m_mpi_env;
}

private:
/**
* @brief Id for the @ref abort_loop. Has to be 0.
*/
enum { LOOP_ABORT = 0 };
static constexpr int LOOP_ABORT = 0;

/**
* @brief If @ref abort_loop should be called on destruction
* on the head node.
* The MPI communicator used for the callbacks.
*/
bool m_abort_on_exit;
boost::mpi::communicator m_comm;

/**
* The MPI communicator used for the callbacks.
* The MPI environment used for the callbacks.
*/
boost::mpi::communicator m_comm;
std::shared_ptr<boost::mpi::environment> m_mpi_env;

/**
* Internal storage for the callback functions.
Expand Down
20 changes: 12 additions & 8 deletions src/core/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

#include <boost/mpi.hpp>
#include <boost/mpi/communicator.hpp>
#include <boost/mpi/environment.hpp>

#include <mpi.h>

Expand All @@ -47,17 +48,20 @@ boost::mpi::communicator comm_cart;
Communicator communicator{};

namespace Communication {
static auto const &mpi_datatype_cache =
boost::mpi::detail::mpi_datatype_cache();
static std::shared_ptr<boost::mpi::environment> mpi_env;
static std::unique_ptr<MpiCallbacks> m_callbacks;
static std::shared_ptr<MpiCallbacks> m_callbacks;

/* We use a singleton callback class for now. */
MpiCallbacks &mpiCallbacks() {
assert(m_callbacks && "Mpi not initialized!");

return *m_callbacks;
}

std::shared_ptr<MpiCallbacks> mpiCallbacksHandle() {
assert(m_callbacks && "Mpi not initialized!");

return m_callbacks;
}
} // namespace Communication

using Communication::mpiCallbacks;
Expand All @@ -66,14 +70,12 @@ int this_node = -1;

namespace Communication {
void init(std::shared_ptr<boost::mpi::environment> mpi_env) {
Communication::mpi_env = std::move(mpi_env);

communicator.full_initialization();

Communication::m_callbacks =
std::make_unique<Communication::MpiCallbacks>(comm_cart);
std::make_shared<Communication::MpiCallbacks>(comm_cart, mpi_env);

ErrorHandling::init_error_handling(mpiCallbacks());
ErrorHandling::init_error_handling(Communication::m_callbacks);

#ifdef WALBERLA
walberla::mpi_init();
Expand All @@ -83,6 +85,8 @@ void init(std::shared_ptr<boost::mpi::environment> mpi_env) {
cuda_on_program_start();
#endif
}

void deinit() { Communication::m_callbacks.reset(); }
} // namespace Communication

Communicator::Communicator()
Expand Down
15 changes: 11 additions & 4 deletions src/core/communication.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ namespace Communication {
* @brief Returns a reference to the global callback class instance.
*/
MpiCallbacks &mpiCallbacks();
std::shared_ptr<MpiCallbacks> mpiCallbacksHandle();
} // namespace Communication

/**************************************************
Expand Down Expand Up @@ -124,12 +125,18 @@ namespace Communication {
/**
* @brief Init globals for communication.
*
* and calls @ref cuda_on_program_start. Keeps a copy of
* the pointer to the mpi environment to keep it alive
* while the program is loaded.
*
* @param mpi_env MPI environment that should be used
*/
void init(std::shared_ptr<boost::mpi::environment> mpi_env);
void deinit();
} // namespace Communication

struct MpiContainerUnitTest {
std::shared_ptr<boost::mpi::environment> m_mpi_env;
MpiContainerUnitTest(int argc, char **argv) {
m_mpi_env = mpi_init(argc, argv);
Communication::init(m_mpi_env);
}
~MpiContainerUnitTest() { Communication::deinit(); }
};
#endif
13 changes: 7 additions & 6 deletions src/core/errorhandling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace ErrorHandling {
Expand All @@ -44,13 +45,13 @@ namespace ErrorHandling {
static std::unique_ptr<RuntimeErrorCollector> runtimeErrorCollector;

/** The callback loop we are on. */
static Communication::MpiCallbacks *m_callbacks = nullptr;
static std::weak_ptr<Communication::MpiCallbacks> m_callbacks;

void init_error_handling(Communication::MpiCallbacks &cb) {
m_callbacks = &cb;
void init_error_handling(std::weak_ptr<Communication::MpiCallbacks> callbacks) {
m_callbacks = std::move(callbacks);

runtimeErrorCollector =
std::make_unique<RuntimeErrorCollector>(m_callbacks->comm());
std::make_unique<RuntimeErrorCollector>(m_callbacks.lock()->comm());
}

RuntimeErrorStream _runtimeMessageStream(RuntimeError::ErrorLevel level,
Expand All @@ -67,7 +68,7 @@ static void mpi_gather_runtime_errors_local() {
REGISTER_CALLBACK(mpi_gather_runtime_errors_local)

std::vector<RuntimeError> mpi_gather_runtime_errors() {
m_callbacks->call(mpi_gather_runtime_errors_local);
m_callbacks.lock()->call(mpi_gather_runtime_errors_local);
return runtimeErrorCollector->gather();
}

Expand All @@ -81,7 +82,7 @@ std::vector<RuntimeError> mpi_gather_runtime_errors_all(bool is_head_node) {
} // namespace ErrorHandling

void errexit() {
ErrorHandling::m_callbacks->comm().abort(1);
ErrorHandling::m_callbacks.lock()->comm().abort(1);

std::abort();
}
Expand Down
3 changes: 2 additions & 1 deletion src/core/errorhandling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "error_handling/RuntimeError.hpp"
#include "error_handling/RuntimeErrorStream.hpp"

#include <memory>
#include <string>
#include <vector>

Expand Down Expand Up @@ -85,7 +86,7 @@ namespace ErrorHandling {
*
* @param callbacks Callbacks system the error handler should be on.
*/
void init_error_handling(Communication::MpiCallbacks &callbacks);
void init_error_handling(std::weak_ptr<Communication::MpiCallbacks> callbacks);

RuntimeErrorStream _runtimeMessageStream(RuntimeError::ErrorLevel level,
const std::string &file, int line,
Expand Down
2 changes: 1 addition & 1 deletion src/core/reaction_methods/tests/ReactionAlgorithm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ BOOST_FIXTURE_TEST_CASE(ReactionAlgorithm_test, ParticleFactory) {
}

int main(int argc, char **argv) {
mpi_init_stand_alone(argc, argv);
auto const mpi_handle = MpiContainerUnitTest(argc, argv);
espresso::system = System::System::create();
espresso::system->set_cell_structure_topology(CellStructureType::REGULAR);
::System::set_system(espresso::system);
Expand Down
7 changes: 0 additions & 7 deletions src/core/system/System.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,10 +463,3 @@ unsigned System::get_global_ghost_flags() const {
}

} // namespace System

void mpi_init_stand_alone(int argc, char **argv) {
auto mpi_env = mpi_init(argc, argv);

// initialize the MpiCallbacks framework
Communication::init(mpi_env);
}
7 changes: 0 additions & 7 deletions src/core/system/System.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,3 @@ void reset_system();
bool is_system_set();

} // namespace System

/**
* @brief Initialize MPI global state to run ESPResSo in stand-alone mode.
* Use this function in simulations written in C++, such as unit tests.
* The script interface has its own MPI initialization mechanism.
*/
void mpi_init_stand_alone(int argc, char **argv);
2 changes: 1 addition & 1 deletion src/core/unit_tests/EspressoSystemStandAlone_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ BOOST_FIXTURE_TEST_CASE(espresso_system_stand_alone, ParticleFactory) {
}

int main(int argc, char **argv) {
mpi_init_stand_alone(argc, argv);
auto const mpi_handle = MpiContainerUnitTest(argc, argv);
espresso::system = System::System::create();
espresso::system->set_cell_structure_topology(CellStructureType::REGULAR);
::System::set_system(espresso::system);
Expand Down
2 changes: 1 addition & 1 deletion src/core/unit_tests/EspressoSystem_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ BOOST_FIXTURE_TEST_CASE(check_with_gpu, ParticleFactory,
}

int main(int argc, char **argv) {
mpi_init_stand_alone(argc, argv);
auto const mpi_handle = MpiContainerUnitTest(argc, argv);

return boost::unit_test::unit_test_main(init_unit_test, argc, argv);
}
Expand Down
Loading
Loading