Skip to content

Commit

Permalink
script_interface: Add automatic feature checks
Browse files Browse the repository at this point in the history
  • Loading branch information
jngrad committed Jul 10, 2023
1 parent 780358f commit 866e6aa
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 102 deletions.
27 changes: 4 additions & 23 deletions src/python/espressomd/electrokinetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from .script_interface import ScriptInterfaceHelper, script_interface_register, ScriptObjectList, array_variant
import espressomd.detail.walberla
import espressomd.shapes
import espressomd.code_features


@script_interface_register
Expand All @@ -34,16 +33,10 @@ class EKFFT(ScriptInterfaceHelper):
A FFT-based Poisson solver.
"""

_so_name = "walberla::EKFFT"
_so_features = ("WALBERLA_FFT",)
_so_creation_policy = "GLOBAL"

def __init__(self, *args, **kwargs):
if not espressomd.code_features.has_features("WALBERLA_FFT"):
raise NotImplementedError("Feature WALBERLA not compiled in")

super().__init__(*args, **kwargs)


@script_interface_register
class EKNone(ScriptInterfaceHelper):
Expand All @@ -53,24 +46,14 @@ class EKNone(ScriptInterfaceHelper):
"""
_so_name = "walberla::EKNone"
_so_features = ("WALBERLA",)
_so_creation_policy = "GLOBAL"

def __init__(self, *args, **kwargs):
if not espressomd.code_features.has_features("WALBERLA"):
raise NotImplementedError("Feature WALBERLA not compiled in")

super().__init__(*args, **kwargs)


@script_interface_register
class EKContainer(ScriptObjectList):
_so_name = "walberla::EKContainer"

def __init__(self, *args, **kwargs):
if not espressomd.code_features.has_features("WALBERLA"):
raise NotImplementedError("Feature WALBERLA not compiled in")

super().__init__(*args, **kwargs)
_so_features = ("WALBERLA",)

def add(self, ekspecies):
self.call_method("add", object=ekspecies)
Expand Down Expand Up @@ -165,6 +148,7 @@ class EKSpecies(ScriptInterfaceHelper,
"""

_so_name = "walberla::EKSpecies"
_so_features = ("WALBERLA",)
_so_creation_policy = "GLOBAL"
_so_bind_methods = (
"clear_density_boundaries",
Expand All @@ -176,9 +160,6 @@ class EKSpecies(ScriptInterfaceHelper,
)

def __init__(self, *args, **kwargs):
if not espressomd.code_features.has_features("WALBERLA"):
raise NotImplementedError("Feature WALBERLA not compiled in")

if "sip" not in kwargs:
params = self.default_params()
params.update(kwargs)
Expand Down
44 changes: 7 additions & 37 deletions src/python/espressomd/electrostatics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,14 @@

from . import utils
from .script_interface import ScriptInterfaceHelper, script_interface_register
from .code_features import has_features


@script_interface_register
class Container(ScriptInterfaceHelper):
_so_name = "Coulomb::Container"
_so_features = ("ELECTROSTATICS",)
_so_bind_methods = ("clear",)

def __init__(self, *args, **kwargs):
if not has_features("ELECTROSTATICS"):
raise NotImplementedError("Feature ELECTROSTATICS not compiled in")

super().__init__(*args, **kwargs)


class ElectrostaticInteraction(ScriptInterfaceHelper):
"""
Expand All @@ -45,10 +39,9 @@ class ElectrostaticInteraction(ScriptInterfaceHelper):
"""
_so_creation_policy = "GLOBAL"
_so_features = ("ELECTROSTATICS",)

def __init__(self, **kwargs):
self._check_required_features()

if 'sip' not in kwargs:
for key in self.required_keys():
if key not in kwargs:
Expand All @@ -64,10 +57,6 @@ def __init__(self, **kwargs):
else:
super().__init__(**kwargs)

def _check_required_features(self):
if not has_features("ELECTROSTATICS"):
raise NotImplementedError("Feature ELECTROSTATICS not compiled in")

def validate_params(self, params):
"""Check validity of given parameters.
"""
Expand Down Expand Up @@ -243,10 +232,7 @@ class P3M(_P3MBase):
"""
_so_name = "Coulomb::CoulombP3M"
_so_creation_policy = "GLOBAL"

def _check_required_features(self):
if not has_features("P3M"):
raise NotImplementedError("Feature P3M not compiled in")
_so_features = ("P3M",)


@script_interface_register
Expand Down Expand Up @@ -296,12 +282,7 @@ class P3MGPU(_P3MBase):
"""
_so_name = "Coulomb::CoulombP3MGPU"
_so_creation_policy = "GLOBAL"

def _check_required_features(self):
if not has_features("P3M"):
raise NotImplementedError("Feature P3M not compiled in")
if not has_features("CUDA"):
raise NotImplementedError("Feature CUDA not compiled in")
_so_features = ("P3M", "CUDA")


@script_interface_register
Expand Down Expand Up @@ -360,10 +341,7 @@ class ELC(ElectrostaticInteraction):
"""
_so_name = "Coulomb::ElectrostaticLayerCorrection"
_so_creation_policy = "GLOBAL"

def _check_required_features(self):
if not has_features("P3M"):
raise NotImplementedError("Feature P3M not compiled in")
_so_features = ("P3M",)

def validate_params(self, params):
utils.check_type_or_throw_except(
Expand Down Expand Up @@ -447,10 +425,7 @@ class MMM1DGPU(ElectrostaticInteraction):
"""
_so_name = "Coulomb::CoulombMMM1DGpu"
_so_creation_policy = "GLOBAL"

def _check_required_features(self):
if not has_features("MMM1D_GPU"):
raise NotImplementedError("Feature MMM1D_GPU not compiled in")
_so_features = ("MMM1D_GPU",)

def default_params(self):
return {"far_switch_radius": -1.,
Expand Down Expand Up @@ -505,17 +480,12 @@ class Scafacos(ElectrostaticInteraction):
"""
_so_name = "Coulomb::CoulombScafacos"
_so_creation_policy = "GLOBAL"
_so_features = ("ELECTROSTATICS", "SCAFACOS")
_so_bind_methods = ElectrostaticInteraction._so_bind_methods + \
("get_available_methods",
"get_near_field_delegation",
"set_near_field_delegation")

def _check_required_features(self):
if not has_features("ELECTROSTATICS"):
raise NotImplementedError("Feature ELECTROSTATICS not compiled in")
if not has_features("SCAFACOS"):
raise NotImplementedError("Feature SCAFACOS not compiled in")

def default_params(self):
return {"check_neutrality": True}

Expand Down
9 changes: 2 additions & 7 deletions src/python/espressomd/lb.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ class LBFluidWalberla(HydrodynamicInteraction,
"""

_so_name = "walberla::LBFluid"
_so_features = ("WALBERLA",)
_so_creation_policy = "GLOBAL"
_so_bind_methods = (
"add_force_at_pos",
Expand All @@ -237,9 +238,6 @@ class LBFluidWalberla(HydrodynamicInteraction,
)

def __init__(self, *args, **kwargs):
if not espressomd.code_features.has_features("WALBERLA"):
raise NotImplementedError("Feature WALBERLA not compiled in")

if "sip" not in kwargs:
params = self.default_params()
params.update(kwargs)
Expand Down Expand Up @@ -329,13 +327,10 @@ class LBFluidWalberlaGPU(HydrodynamicInteraction):
list of parameters.
"""
_so_features = ("WALBERLA", "CUDA")

# pylint: disable=unused-argument
def __init__(self, *args, **kwargs):
if not espressomd.code_features.has_features("CUDA"):
raise NotImplementedError("Feature CUDA not compiled in")
if not espressomd.code_features.has_features("WALBERLA"):
raise NotImplementedError("Feature WALBERLA not compiled in")
raise NotImplementedError("Not implemented yet")


Expand Down
40 changes: 6 additions & 34 deletions src/python/espressomd/magnetostatics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,14 @@

from . import utils
from .script_interface import ScriptInterfaceHelper, script_interface_register
from .code_features import has_features


@script_interface_register
class Container(ScriptInterfaceHelper):
_so_name = "Dipoles::Container"
_so_features = ("DIPOLES",)
_so_bind_methods = ("clear",)

def __init__(self, *args, **kwargs):
if not has_features("DIPOLES"):
raise NotImplementedError("Feature DIPOLES not compiled in")

super().__init__(*args, **kwargs)


class MagnetostaticInteraction(ScriptInterfaceHelper):
"""
Expand All @@ -45,10 +39,9 @@ class MagnetostaticInteraction(ScriptInterfaceHelper):
"""
_so_creation_policy = "GLOBAL"
_so_features = ("DIPOLES",)

def __init__(self, **kwargs):
self._check_required_features()

if 'sip' not in kwargs:
for key in self.required_keys():
if key not in kwargs:
Expand All @@ -65,10 +58,6 @@ def __init__(self, **kwargs):
else:
super().__init__(**kwargs)

def _check_required_features(self):
if not has_features("DIPOLES"):
raise NotImplementedError("Feature DIPOLES not compiled in")

def validate_params(self, params):
"""Check validity of given parameters.
"""
Expand Down Expand Up @@ -128,10 +117,7 @@ class DipolarP3M(MagnetostaticInteraction):
"""
_so_name = "Dipoles::DipolarP3M"

def _check_required_features(self):
if not has_features("DP3M"):
raise NotImplementedError("Feature DP3M not compiled in")
_so_features = ("DP3M",)

def validate_params(self, params):
"""Check validity of parameters.
Expand Down Expand Up @@ -234,16 +220,10 @@ class Scafacos(MagnetostaticInteraction):
"""
_so_name = "Dipoles::DipolarScafacos"
_so_creation_policy = "GLOBAL"
_so_features = ("DIPOLES", "SCAFACOS_DIPOLES")
_so_bind_methods = MagnetostaticInteraction._so_bind_methods + \
("get_available_methods", )

def _check_required_features(self):
if not has_features("DIPOLES"):
raise NotImplementedError("Feature DIPOLES not compiled in")
if not has_features("SCAFACOS_DIPOLES"):
raise NotImplementedError(
"Feature SCAFACOS_DIPOLES not compiled in")

def default_params(self):
return {}

Expand Down Expand Up @@ -274,11 +254,7 @@ class DipolarDirectSumGpu(MagnetostaticInteraction):
"""
_so_name = "Dipoles::DipolarDirectSumGpu"
_so_creation_policy = "GLOBAL"

def _check_required_features(self):
if not has_features("DIPOLAR_DIRECT_SUM"):
raise NotImplementedError(
"Features CUDA and DIPOLES not compiled in")
_so_features = ("DIPOLAR_DIRECT_SUM", "CUDA")

def default_params(self):
return {}
Expand Down Expand Up @@ -310,11 +286,7 @@ class DipolarBarnesHutGpu(MagnetostaticInteraction):
"""
_so_name = "Dipoles::DipolarBarnesHutGpu"
_so_creation_policy = "GLOBAL"

def _check_required_features(self):
if not has_features("DIPOLAR_BARNES_HUT"):
raise NotImplementedError(
"Features CUDA and DIPOLES not compiled in")
_so_features = ("DIPOLAR_BARNES_HUT", "CUDA")

def default_params(self):
return {"epssq": 100.0, "itolsq": 4.0}
Expand Down
3 changes: 3 additions & 0 deletions src/python/espressomd/script_interface.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,7 @@ cdef extern from "script_interface/initialize.hpp" namespace "ScriptInterface":
cdef extern from "script_interface/get_value.hpp" namespace "ScriptInterface":
T get_value[T](const Variant T)

cdef extern from "script_interface/code_info/CodeInfo.hpp" namespace "ScriptInterface::CodeInfo":
void check_features(const vector[string] & features) except +

cdef void init(MpiCallbacks &)
6 changes: 6 additions & 0 deletions src/python/espressomd/script_interface.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -426,10 +426,16 @@ def _unpickle_so_class(so_name, state):

class ScriptInterfaceHelper(PScriptInterface):
_so_name = None
_so_features = ()
_so_bind_methods = ()
_so_creation_policy = "GLOBAL"

def __init__(self, **kwargs):
cdef vector[string] features_vec
if self._so_features:
for feature in self._so_features:
features_vec.push_back(utils.to_char_pointer(feature))
check_features(features_vec)
super().__init__(self._so_name, policy=self._so_creation_policy,
**kwargs)
self.define_bound_methods()
Expand Down
25 changes: 25 additions & 0 deletions src/script_interface/code_info/CodeInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,19 @@
#include "config/version.hpp"
#include "script_interface/scafacos/scafacos.hpp"

#include <boost/algorithm/string/join.hpp>

#include <algorithm>
#include <string>
#include <vector>

namespace ScriptInterface {
namespace CodeInfo {

static auto get_feature_vector(char const *const ptr[], unsigned int len) {
return std::vector<std::string>{ptr, ptr + len};
}

static Variant get_feature_list(char const *const ptr[], unsigned int len) {
return make_vector_of_variants(std::vector<std::string>{ptr, ptr + len});
}
Expand All @@ -54,5 +61,23 @@ Variant CodeInfo::do_call_method(std::string const &name,
return {};
}

void check_features(std::vector<std::string> const &features) {
auto const allowed = get_feature_vector(FEATURES_ALL, NUM_FEATURES_ALL);
auto const built = get_feature_vector(FEATURES, NUM_FEATURES);
std::vector<std::string> missing_features{};
for (auto const &feature : features) {
if (std::find(allowed.begin(), allowed.end(), feature) == allowed.end()) {
throw std::runtime_error("Unknown feature '" + feature + "'");
}
if (std::find(built.begin(), built.end(), feature) == built.end()) {
missing_features.emplace_back(feature);
}
}
if (not missing_features.empty()) {
throw std::runtime_error("Missing features " +
boost::algorithm::join(missing_features, ", "));
}
}

} // namespace CodeInfo
} // namespace ScriptInterface
Loading

0 comments on commit 866e6aa

Please sign in to comment.