Skip to content

Commit

Permalink
Merge pull request #1853 from astatt/fix_ext_plugin
Browse files Browse the repository at this point in the history
Add _ext_module variable to base python classes
  • Loading branch information
joaander committed Jul 22, 2024
2 parents 14ffc6c + 93a67cd commit 3ac3ed3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
8 changes: 6 additions & 2 deletions hoomd/md/constrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,16 @@ class Constraint(Force):
for `isinstance` or `issubclass` checks.
"""

# Module where the C++ class is defined. Reassign this when developing an
# external plugin.
_ext_module = _md

def _attach_hook(self):
"""Create the c++ mirror class."""
if isinstance(self._simulation.device, hoomd.device.CPU):
cpp_cls = getattr(_md, self._cpp_class_name)
cpp_cls = getattr(self._ext_module, self._cpp_class_name)
else:
cpp_cls = getattr(_md, self._cpp_class_name + "GPU")
cpp_cls = getattr(self._ext_module, self._cpp_class_name + "GPU")

self._cpp_obj = cpp_cls(self._simulation.state._cpp_sys_def)

Expand Down
8 changes: 6 additions & 2 deletions hoomd/md/external/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ class Field(force.Force):
for `isinstance` or `issubclass` checks.
"""

# Module where the C++ class is defined. Reassign this when developing an
# external plugin.
_ext_module = _md

def _attach_hook(self):
if isinstance(self._simulation.device, hoomd.device.CPU):
cls = getattr(_md, self._cpp_class_name)
cls = getattr(self._ext_module, self._cpp_class_name)
else:
cls = getattr(_md, self._cpp_class_name + "GPU")
cls = getattr(self._ext_module, self._cpp_class_name + "GPU")

self._cpp_obj = cls(self._simulation.state._cpp_sys_def)

Expand Down
8 changes: 6 additions & 2 deletions hoomd/md/external/wall.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,19 @@ class WallPotential(force.Force):
potentials.
"""

# Module where the C++ class is defined. Reassign this when developing an
# external plugin.
_ext_module = _md

def __init__(self, walls):
self._walls = None
self.walls = hoomd.wall._WallsMetaList(walls, _to_md_cpp_wall)

def _attach_hook(self):
if isinstance(self._simulation.device, hoomd.device.CPU):
cls = getattr(_md, self._cpp_class_name)
cls = getattr(self._ext_module, self._cpp_class_name)
else:
cls = getattr(_md, self._cpp_class_name + "GPU")
cls = getattr(self._ext_module, self._cpp_class_name + "GPU")
self._cpp_obj = cls(self._simulation.state._cpp_sys_def)
self._walls._sync({
hoomd.wall.Sphere:
Expand Down

0 comments on commit 3ac3ed3

Please sign in to comment.