Skip to content

Commit

Permalink
Script interface maintenance (#4394)
Browse files Browse the repository at this point in the history
Description of changes:
- allow `@property`-setters in `ScriptInterfaceHelper` classes instead of silently skipping them
- unit test the `ScriptInterfaceHelper` class in python
- re-enable benchmark tests
- fix regressions introduced by #4350
- store bond objects in the ObjectMap (see #4391 (comment))
- stop using the core global variable `this_node` in the script interface
  • Loading branch information
kodiakhq[bot] authored Nov 23, 2021
2 parents b801493 + 55bce86 commit f463f42
Show file tree
Hide file tree
Showing 22 changed files with 214 additions and 78 deletions.
4 changes: 4 additions & 0 deletions maintainer/CI/build_cmake.sh
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ cmake_params="-DCMAKE_BUILD_TYPE=${build_type} -DCMAKE_CXX_STANDARD=${with_cxx_s
cmake_params="${cmake_params} -DCMAKE_INSTALL_PREFIX=/tmp/espresso-unit-tests"
cmake_params="${cmake_params} -DCTEST_ARGS=-j${check_procs} -DTEST_TIMEOUT=${test_timeout}"

if [ "${make_check_benchmarks}" = true ]; then
cmake_params="${cmake_params} -DWITH_BENCHMARKS=ON"
fi

if [ "${with_ccache}" = true ]; then
cmake_params="${cmake_params} -DWITH_CCACHE=ON"
fi
Expand Down
2 changes: 1 addition & 1 deletion maintainer/benchmarks/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ python_benchmark(FILE lj.py ARGUMENTS
"--particles_per_core=10000;--volume_fraction=0.50")
python_benchmark(FILE lj.py ARGUMENTS
"--particles_per_core=10000;--volume_fraction=0.02")
python_benchmark(FILE MC-acid-base-reservoir.py ARGUMENTS
python_benchmark(FILE mc_acid_base_reservoir.py ARGUMENTS
"--particles_per_core=500;--mode=benchmark")
python_benchmark(
FILE lj.py ARGUMENTS
Expand Down
2 changes: 1 addition & 1 deletion src/python/espressomd/collision_detection.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class CollisionDetection(ScriptInterfaceHelper):
if name == "mode":
res = self._str_mode(value)

# Convert bond parameters from bond ids to into BondedInteractions
# Get bonded interaction
if name in ["bond_centers", "bond_vs", "bond_three_particle_binding"]:
if value == -1: # Not defined
res = None
Expand Down
108 changes: 65 additions & 43 deletions src/python/espressomd/interactions.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1672,38 +1672,39 @@ class BondedInteraction(ScriptInterfaceHelper):

def __init__(self, *args, **kwargs):
if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
# this branch is only visited by checkpointing constructor #2
kwargs = args[0]
args = []

# Interaction id as argument
if len(args) == 1 and is_valid_type(args[0], int):
bond_id = args[0]

# Check if the bond type in ESPResSo core matches this class
if get_bonded_interaction_type_from_es_core(
bond_id) != self.type_number():
raise Exception(
"The bond with this id is not defined as a " + self.type_name() + " bond in the ESPResSo core.")

super().__init__(bond_id=bond_id)
self._bond_id = bond_id

# Load the parameters currently set in the ESPResSo core
self._ctor_params = self._get_params_from_es_core()

# Or have we been called with keyword args describing the interaction
elif len(args) == 0:
params = self.get_default_params()
params.update(kwargs)
self.validate_params(params)
super().__init__(*args, **params)
self._check_keys(params.keys(), check_required=True)
self._ctor_params = params
self._bond_id = -1

if not 'sip' in kwargs:
if len(args) == 1 and is_valid_type(args[0], int):
# create a new script interface object for a bond that already
# exists in the core via bond_id (checkpointing constructor #1)
bond_id = args[0]
# Check if the bond type in ESPResSo core matches this class
if get_bonded_interaction_type_from_es_core(
bond_id) != self.type_number():
raise Exception(
f"The bond with this id is not defined as a "
f"{self.type_name()} bond in the ESPResSo core.")
super().__init__(bond_id=bond_id)
self._bond_id = bond_id
self._ctor_params = self._get_params_from_es_core()
else:
# create a bond from bond parameters
params = self.get_default_params()
params.update(kwargs)
self.validate_params(params)
super().__init__(*args, **params)
self._check_keys(params.keys(), check_required=True)
self._ctor_params = params
self._bond_id = -1
else:
raise Exception(
"The constructor has to be called either with a bond id (as integer), or with a set of keyword arguments describing a new interaction")
# create a new bond based on a bond in the script interface
# (checkpointing constructor #2 or BondedInteractions getter)
super().__init__(**kwargs)
self._bond_id = -1
self._ctor_params = self._get_params_from_es_core()

def _check_keys(self, keys, check_required=False):
def err_msg(key_set):
Expand All @@ -1722,8 +1723,10 @@ class BondedInteraction(ScriptInterfaceHelper):

def __reduce__(self):
if self._bond_id != -1:
# checkpointing constructor #1
return (self.__class__, (self._bond_id,))
else:
# checkpointing constructor #2
return (self.__class__, (self._ctor_params,))

def __setattr__(self, attr, value):
Expand All @@ -1735,7 +1738,7 @@ class BondedInteraction(ScriptInterfaceHelper):

@params.setter
def params(self, p):
raise Exception("Bonds are immutable.")
raise RuntimeError("Bond parameters are immutable.")

def validate_params(self, params):
"""Check that parameters are valid.
Expand Down Expand Up @@ -1911,6 +1914,7 @@ class HarmonicBond(BondedInteraction):

if ELECTROSTATICS:

@script_interface_register
class BondedCoulomb(BondedInteraction):

"""
Expand Down Expand Up @@ -1971,6 +1975,7 @@ if ELECTROSTATICS:
return {}


@script_interface_register
class ThermalizedBond(BondedInteraction):

"""
Expand Down Expand Up @@ -2100,6 +2105,7 @@ IF THOLE:

IF BOND_CONSTRAINT == 1:

@script_interface_register
class RigidBond(BondedInteraction):

"""
Expand Down Expand Up @@ -2142,6 +2148,7 @@ ELSE:
name = "RIGID"


@script_interface_register
class Dihedral(BondedInteraction):

"""
Expand Down Expand Up @@ -2223,6 +2230,7 @@ IF TABULATED:
"""
return {}

@script_interface_register
class TabulatedDistance(_TabulatedBase):

"""
Expand Down Expand Up @@ -2256,6 +2264,7 @@ IF TABULATED:
"""
return "TABULATED_DISTANCE"

@script_interface_register
class TabulatedAngle(_TabulatedBase):

"""
Expand Down Expand Up @@ -2298,6 +2307,7 @@ IF TABULATED:
raise ValueError(f"Tabulated angle expects forces/energies "
f"within the range [0, pi], got {phi}")

@script_interface_register
class TabulatedDihedral(_TabulatedBase):

"""
Expand Down Expand Up @@ -2419,6 +2429,7 @@ IF TABULATED:
return True


@script_interface_register
class Virtual(BondedInteraction):

"""
Expand Down Expand Up @@ -2446,6 +2457,7 @@ class Virtual(BondedInteraction):
return {}


@script_interface_register
class AngleHarmonic(BondedInteraction):

"""
Expand Down Expand Up @@ -2478,6 +2490,7 @@ class AngleHarmonic(BondedInteraction):
return {}


@script_interface_register
class AngleCosine(BondedInteraction):

"""
Expand Down Expand Up @@ -2510,6 +2523,7 @@ class AngleCosine(BondedInteraction):
return {}


@script_interface_register
class AngleCossquare(BondedInteraction):

"""
Expand Down Expand Up @@ -2542,6 +2556,7 @@ class AngleCossquare(BondedInteraction):
return {}


@script_interface_register
class IBM_Triel(BondedInteraction):

"""
Expand Down Expand Up @@ -2600,6 +2615,7 @@ class IBM_Triel(BondedInteraction):
"elasticLaw": self.elasticLaw}


@script_interface_register
class IBM_Tribend(BondedInteraction):

"""
Expand Down Expand Up @@ -2650,6 +2666,7 @@ class IBM_Tribend(BondedInteraction):
return {"kb": self.kb, "theta0": self.theta0}


@script_interface_register
class IBM_VolCons(BondedInteraction):

"""
Expand Down Expand Up @@ -2694,6 +2711,7 @@ class IBM_VolCons(BondedInteraction):
return immersed_boundaries.get_current_volume(self.softID)


@script_interface_register
class OifGlobalForces(BondedInteraction):

"""
Expand Down Expand Up @@ -2733,6 +2751,7 @@ class OifGlobalForces(BondedInteraction):
return {}


@script_interface_register
class OifLocalForces(BondedInteraction):

"""
Expand Down Expand Up @@ -2781,6 +2800,7 @@ class OifLocalForces(BondedInteraction):
return {}


@script_interface_register
class QuarticBond(BondedInteraction):

"""
Expand Down Expand Up @@ -2866,15 +2886,6 @@ class BondedInteractions(ScriptObjectRegistry):
_so_name = "Interactions::BondedInteractions"
_so_creation_policy = "GLOBAL"

def __init__(self, *args, **kwargs):
if args:
params, (_unpickle_so_class, (_so_name, bytestring)) = args
assert _so_name == self._so_name
self = _unpickle_so_class(_so_name, bytestring)
self.__setstate__(params)
else:
super().__init__(**kwargs)

def add(self, *args, **kwargs):
"""
Add a bond to the list.
Expand Down Expand Up @@ -2924,13 +2935,17 @@ class BondedInteractions(ScriptObjectRegistry):
raise ValueError(
"Index to BondedInteractions[] has to be an integer referring to a bond id")

if self.call_method('has_bond', bond_id=bond_id):
bond_obj = self.call_method('get_bond', bond_id=bond_id)
bond_obj._bond_id = bond_id
return bond_obj

# Find out the type of the interaction from ESPResSo
bond_type = get_bonded_interaction_type_from_es_core(bond_id)

# Check if the bonded interaction exists in ESPResSo core
if bond_type == BONDED_IA_NONE:
raise ValueError(
f"The bonded interaction with the id {bond_id} is not yet defined.")
raise ValueError(f"The bond with id {bond_id} is not yet defined.")

# Find the appropriate class representing such a bond
bond_class = bonded_interaction_classes[bond_type]
Expand Down Expand Up @@ -2986,7 +3001,7 @@ class BondedInteractions(ScriptObjectRegistry):
return bond_id

def __len__(self):
return bonded_ia_params_size()
return self.call_method('get_size')

# Support iteration over active bonded interactions
def __iter__(self):
Expand All @@ -2995,8 +3010,9 @@ class BondedInteractions(ScriptObjectRegistry):
yield self[bond_id]

def __reduce__(self):
so_reduce = super().__reduce__()
return (self.__class__, (self.__getstate__(), so_reduce))
so_callback, (so_name, so_bytestring) = super().__reduce__()
return (_restore_bonded_interactions,
(so_callback, (so_name, so_bytestring), self.__getstate__()))

def __getstate__(self):
params = {}
Expand All @@ -3012,3 +3028,9 @@ class BondedInteractions(ScriptObjectRegistry):
for bond_id, (bond_params, bond_type) in params.items():
self[bond_id] = bonded_interaction_classes[bond_type](
**bond_params)


def _restore_bonded_interactions(so_callback, so_callback_args, state):
so = so_callback(*so_callback_args)
so.__setstate__(state)
return so
2 changes: 1 addition & 1 deletion src/python/espressomd/particle_data.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,7 @@ cdef class ParticleHandle:
# Validity of the numeric id
if not bonded_ia_params_zero_based_type(bond[0]._bond_id):
raise ValueError(
f"The bond type f{bond[0]._bond_id} does not exist.")
f"The bond type {bond[0]._bond_id} does not exist.")

# Number of partners
expected_num_partners = bond[0].call_method('get_num_partners')
Expand Down
8 changes: 7 additions & 1 deletion src/python/espressomd/script_interface.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,13 @@ class ScriptInterfaceHelper(PScriptInterface):
if attr in self._valid_parameters():
self.set_params(**{attr: value})
else:
self.__dict__[attr] = value
super().__setattr__(attr, value)

def __delattr__(self, attr):
if attr in self._valid_parameters():
raise RuntimeError(f"Parameter '{attr}' is read-only")
else:
super().__delattr__(attr)

def generate_caller(self, method_name):
def template_method(**kwargs):
Expand Down
2 changes: 2 additions & 0 deletions src/script_interface/Context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class Context : public std::enable_shared_from_this<Context> {
*/
virtual boost::string_ref name(const ObjectHandle *o) const = 0;

virtual bool is_head_node() const = 0;

virtual ~Context() = default;
};
} // namespace ScriptInterface
Expand Down
3 changes: 2 additions & 1 deletion src/script_interface/ContextManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ std::string ContextManager::serialize(const ObjectHandle *o) const {

ContextManager::ContextManager(Communication::MpiCallbacks &callbacks,
const Utils::Factory<ObjectHandle> &factory) {
auto local_context = std::make_shared<LocalContext>(factory);
auto const mpi_rank = callbacks.comm().rank();
auto local_context = std::make_shared<LocalContext>(factory, mpi_rank);

/* If there is only one node, we can treat all objects as local, and thus
* never invoke any callback. */
Expand Down
7 changes: 6 additions & 1 deletion src/script_interface/GlobalContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class GlobalContext : public Context {

std::shared_ptr<LocalContext> m_node_local_context;

bool m_is_head_node;

private:
Communication::CallbackHandle<ObjectId, const std::string &,
const PackedMap &>
Expand All @@ -83,7 +85,8 @@ class GlobalContext : public Context {
public:
GlobalContext(Communication::MpiCallbacks &callbacks,
std::shared_ptr<LocalContext> node_local_context)
: m_node_local_context(std::move(node_local_context)),
: m_local_objects(), m_node_local_context(std::move(node_local_context)),
m_is_head_node(callbacks.comm().rank() == 0),
cb_make_handle(&callbacks,
[this](ObjectId id, const std::string &name,
const PackedMap &parameters) {
Expand Down Expand Up @@ -157,6 +160,8 @@ class GlobalContext : public Context {
make_shared(std::string const &name, const VariantMap &parameters) override;

boost::string_ref name(const ObjectHandle *o) const override;

bool is_head_node() const override { return m_is_head_node; };
};
} // namespace ScriptInterface

Expand Down
Loading

0 comments on commit f463f42

Please sign in to comment.