Skip to content

Commit

Permalink
script_interface: Drop dependency on this_node global
Browse files Browse the repository at this point in the history
  • Loading branch information
jngrad committed Nov 21, 2021
1 parent 48357b3 commit 55bce86
Show file tree
Hide file tree
Showing 12 changed files with 28 additions and 13 deletions.
2 changes: 2 additions & 0 deletions src/script_interface/Context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class Context : public std::enable_shared_from_this<Context> {
*/
virtual boost::string_ref name(const ObjectHandle *o) const = 0;

virtual bool is_head_node() const = 0;

virtual ~Context() = default;
};
} // namespace ScriptInterface
Expand Down
3 changes: 2 additions & 1 deletion src/script_interface/ContextManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ std::string ContextManager::serialize(const ObjectHandle *o) const {

ContextManager::ContextManager(Communication::MpiCallbacks &callbacks,
const Utils::Factory<ObjectHandle> &factory) {
auto local_context = std::make_shared<LocalContext>(factory);
auto const mpi_rank = callbacks.comm().rank();
auto local_context = std::make_shared<LocalContext>(factory, mpi_rank);

/* If there is only one node, we can treat all objects as local, and thus
* never invoke any callback. */
Expand Down
7 changes: 6 additions & 1 deletion src/script_interface/GlobalContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class GlobalContext : public Context {

std::shared_ptr<LocalContext> m_node_local_context;

bool m_is_head_node;

private:
Communication::CallbackHandle<ObjectId, const std::string &,
const PackedMap &>
Expand All @@ -83,7 +85,8 @@ class GlobalContext : public Context {
public:
GlobalContext(Communication::MpiCallbacks &callbacks,
std::shared_ptr<LocalContext> node_local_context)
: m_node_local_context(std::move(node_local_context)),
: m_local_objects(), m_node_local_context(std::move(node_local_context)),
m_is_head_node(callbacks.comm().rank() == 0),
cb_make_handle(&callbacks,
[this](ObjectId id, const std::string &name,
const PackedMap &parameters) {
Expand Down Expand Up @@ -157,6 +160,8 @@ class GlobalContext : public Context {
make_shared(std::string const &name, const VariantMap &parameters) override;

boost::string_ref name(const ObjectHandle *o) const override;

bool is_head_node() const override { return m_is_head_node; };
};
} // namespace ScriptInterface

Expand Down
7 changes: 5 additions & 2 deletions src/script_interface/LocalContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ namespace ScriptInterface {
*/
class LocalContext : public Context {
Utils::Factory<ObjectHandle> m_factory;
bool m_is_head_node;

public:
explicit LocalContext(Utils::Factory<ObjectHandle> factory)
: m_factory(std::move(factory)) {}
explicit LocalContext(Utils::Factory<ObjectHandle> factory, int mpi_rank)
: m_factory(std::move(factory)), m_is_head_node(mpi_rank == 0) {}

const Utils::Factory<ObjectHandle> &factory() const { return m_factory; }

Expand All @@ -66,6 +67,8 @@ class LocalContext : public Context {

return factory().type_name(*o);
}

bool is_head_node() const override { return m_is_head_node; };
};
} // namespace ScriptInterface

Expand Down
3 changes: 1 addition & 2 deletions src/script_interface/interactions/BondedInteractions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#include "BondedInteraction.hpp"

#include "core/communication.hpp"
#include "core/bonded_interactions/bonded_interaction_data.hpp"

#include "script_interface/ObjectMap.hpp"
Expand Down Expand Up @@ -87,7 +86,7 @@ class BondedInteractions : public ObjectMap<BondedInteraction> {
auto const bond_id = get_value<int>(params, "bond_id");
// core and script interface must agree
assert(m_bonds.count(bond_id) == ::bonded_ia_params.count(bond_id));
if (this_node != 0)
if (not context()->is_head_node())
return {};
// bond must exist
if (m_bonds.count(bond_id) == 0) {
Expand Down
3 changes: 1 addition & 2 deletions src/script_interface/lbboundaries/LBBoundary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

#include "config.hpp"

#include "core/communication.hpp"
#include "core/grid_based_algorithms/lb_interface.hpp"
#include "core/grid_based_algorithms/lbboundaries/LBBoundary.hpp"
#include "script_interface/ScriptInterface.hpp"
Expand Down Expand Up @@ -69,7 +68,7 @@ class LBBoundary : public AutoParameters<LBBoundary> {
Variant do_call_method(const std::string &name, const VariantMap &) override {
if (name == "get_force") {
// The get force method uses mpi callbacks on lb cpu
if (this_node == 0) {
if (context()->is_head_node()) {
const auto agrid = lb_lbfluid_get_agrid();
const auto tau = lb_lbfluid_get_tau();
const double unit_conversion = agrid / tau / tau;
Expand Down
4 changes: 4 additions & 0 deletions src/script_interface/object_container_mpi_guard.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef ESPRESSO_OBJECT_CONTAINER_MPI_GUARD__HPP
#define ESPRESSO_OBJECT_CONTAINER_MPI_GUARD__HPP

#include <boost/utility/string_ref.hpp>

Expand All @@ -40,3 +42,5 @@
*/
void object_container_mpi_guard(boost::string_ref const &name,
std::size_t n_elements);

#endif
2 changes: 1 addition & 1 deletion src/script_interface/tests/GlobalContext_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ auto make_global_context(Communication::MpiCallbacks &cb) {
factory.register_new<Dummy>("Dummy");

return std::make_shared<si::GlobalContext>(
cb, std::make_shared<si::LocalContext>(factory));
cb, std::make_shared<si::LocalContext>(factory, 0));
}

BOOST_AUTO_TEST_CASE(GlobalContext_make_shared) {
Expand Down
4 changes: 2 additions & 2 deletions src/script_interface/tests/LocalContext_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ auto factory = []() {
}();

BOOST_AUTO_TEST_CASE(LocalContext_make_shared) {
auto ctx = std::make_shared<si::LocalContext>(factory);
auto ctx = std::make_shared<si::LocalContext>(factory, 0);

auto res = ctx->make_shared("Dummy", {});
BOOST_REQUIRE(res != nullptr);
Expand All @@ -66,7 +66,7 @@ BOOST_AUTO_TEST_CASE(LocalContext_make_shared) {
}

BOOST_AUTO_TEST_CASE(LocalContext_serialization) {
auto ctx = std::make_shared<si::LocalContext>(factory);
auto ctx = std::make_shared<si::LocalContext>(factory, 0);

auto const serialized = [&]() {
auto d1 = ctx->make_shared("Dummy", {});
Expand Down
2 changes: 2 additions & 0 deletions src/script_interface/tests/ObjectHandle_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ struct LogContext : public Context {
boost::string_ref name(const ObjectHandle *o) const override {
return "Dummy";
}

bool is_head_node() const override { return true; };
};

/*
Expand Down
2 changes: 1 addition & 1 deletion src/script_interface/tests/ObjectList_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ BOOST_AUTO_TEST_CASE(serialization) {
Utils::Factory<ObjectHandle> f;
f.register_new<ObjectHandle>("ObjectHandle");
f.register_new<ObjectListImpl>("ObjectList");
auto ctx = std::make_shared<LocalContext>(f);
auto ctx = std::make_shared<LocalContext>(f, 0);
// A list of some elements
auto list = std::dynamic_pointer_cast<ObjectListImpl>(
ctx->make_shared("ObjectList", {}));
Expand Down
2 changes: 1 addition & 1 deletion src/script_interface/tests/ObjectMap_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ BOOST_AUTO_TEST_CASE(serialization) {
Utils::Factory<ObjectHandle> f;
f.register_new<ObjectHandle>("ObjectHandle");
f.register_new<ObjectMapImpl>("ObjectMap");
auto ctx = std::make_shared<LocalContext>(f);
auto ctx = std::make_shared<LocalContext>(f, 0);
// A list of some elements
auto map = std::dynamic_pointer_cast<ObjectMapImpl>(
ctx->make_shared("ObjectMap", {}));
Expand Down

0 comments on commit 55bce86

Please sign in to comment.