Skip to content

Commit

Permalink
WorkChain: Protect public methods from being subclassed (#5779)
Browse files Browse the repository at this point in the history
The paradigm of the `WorkChain` requires a user in an implementation to
define the workflow logic through classmethods of the `WorkChain`
subclass. While this gives great flexibility and choice to the user,
there is a risk that a user inadvertently chooses a method name that
already exists on the `WorkChain` base class. Usually the `super` is not
called in this scenario and so the functionality is broken.

The typical example is where the user uses the `run` method as a step in
the outline of the `WorkChain`. The work chain will still run, however,
only that one step in the outline is called. Since the logic to continue
to the next step in the outline is defined in `WorkChain.run`, which is
overridden and now no longer called, the rest of the work chain is
skipped without any warning or error message, leaving the user
scratching their head as to what happened.

Here we protect this and other public methods on the `WorkChain` class
to prevent them from being overridden in subclasses. This is
accomplished by adding the `Protect` class as a metaclass. Since the
`WorkChain` already has the metaclass `plumpy.ProcessStateMachineMeta`,
which it inherits from its `Process` base class, and all metaclasses
need to share the same base, `Protect` also subclasses the
`ProcessStateMachineMeta` class.

The `Protect` class provides the `final` classmethod which can be used
to decorate a method in the `WorkChain` class that should be protected.
If a subclass implements it, as soon as the class is imported, a
`RuntimeError` is raised mentioning that the method cannot be
overridden.

The test `test_report_dbloghandler` had to be fixed because it actually
suffered from the very problem that is being fixed. It used the `run`
method to setup the test, but since the `check` was never being called,
the test always passed, even though the code `self._backend` in the
`check` is incorrect.
  • Loading branch information
sphuber authored Nov 28, 2022
1 parent 8db7fec commit 7929257
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 27 deletions.
98 changes: 81 additions & 17 deletions aiida/engine/processes/workchains/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Components for the WorkChain concept of the workflow engine."""
from __future__ import annotations

import collections.abc
import functools
import logging
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
import typing as t

from plumpy.persistence import auto_persist
from plumpy.process_states import Continue, Wait
from plumpy.processes import ProcessStateMachineMeta
from plumpy.workchains import Stepper
from plumpy.workchains import WorkChainSpec as PlumpyWorkChainSpec
from plumpy.workchains import _PropagateReturn, if_, return_, while_
Expand All @@ -30,8 +33,8 @@
from ..process_spec import ProcessSpec
from .awaitable import Awaitable, AwaitableAction, AwaitableTarget, construct_awaitable

if TYPE_CHECKING:
from aiida.engine.runners import Runner
if t.TYPE_CHECKING:
from aiida.engine.runners import Runner # pylint: disable=unused-import

__all__ = ('WorkChain', 'if_', 'while_', 'return_')

Expand All @@ -40,8 +43,60 @@ class WorkChainSpec(ProcessSpec, PlumpyWorkChainSpec):
pass


class Protect(ProcessStateMachineMeta):
"""Metaclass that allows protecting class methods from being overridden by subclasses.
Usage as follows::
class SomeClass(metaclass=Protect):
@Protect.final
def private_method(self):
"This method cannot be overridden by a subclass."
If a subclass is imported that overrides the subclass, a ``RuntimeError`` is raised.
"""

__SENTINEL = object()

def __new__(cls, name, bases, namespace, **kwargs):
"""Collect all methods that were marked as protected and raise if the subclass defines it.
:raises RuntimeError: If the new class defines (i.e. overrides) a method that was decorated with ``final``.
"""
private = {
key for base in bases for key, value in vars(base).items() if callable(value) and cls.__is_final(value)
}
for key in namespace:
if key in private:
raise RuntimeError(f'the method `{key}` is protected cannot be overridden.')
return super().__new__(cls, name, bases, namespace, **kwargs)

@classmethod
def __is_final(cls, method) -> bool:
"""Return whether the method has been decorated by the ``final`` classmethod.
:return: Boolean, ``True`` if the method is marked as final, ``False`` otherwise.
"""
try:
return method.__final is cls.__SENTINEL # pylint: disable=protected-access
except AttributeError:
return False

@classmethod
def final(cls, method: t.Any):
"""Decorate a method with this method to protect it from being overridden.
Adds the ``__SENTINEL`` object as the ``__final`` private attribute to the given ``method`` and wraps it in
the ``typing.final`` decorator. The latter indicates to typing systems that it cannot be overridden in
subclasses.
"""
method.__final = cls.__SENTINEL # pylint: disable=protected-access,unused-private-member
return t.final(method)


@auto_persist('_awaitables')
class WorkChain(Process):
class WorkChain(Process, metaclass=Protect):
"""The `WorkChain` class is the principle component to implement workflows in AiiDA."""

_node_class = WorkChainNode
Expand All @@ -51,9 +106,9 @@ class WorkChain(Process):

def __init__(
self,
inputs: Optional[dict] = None,
logger: Optional[logging.Logger] = None,
runner: Optional['Runner'] = None,
inputs: dict | None = None,
logger: logging.Logger | None = None,
runner: 'Runner' | None = None,
enable_persistence: bool = True
) -> None:
"""Construct a WorkChain instance.
Expand All @@ -71,8 +126,8 @@ def __init__(

super().__init__(inputs, logger, runner, enable_persistence=enable_persistence)

self._stepper: Optional[Stepper] = None
self._awaitables: List[Awaitable] = []
self._stepper: Stepper | None = None
self._awaitables: list[Awaitable] = []
self._context = AttributeDict()

@classmethod
Expand Down Expand Up @@ -119,11 +174,12 @@ def load_instance_state(self, saved_state, load_context):
if self._awaitables:
self.action_awaitables()

@Protect.final
def on_run(self):
super().on_run()
self.node.set_stepper_state_info(str(self._stepper))

def _resolve_nested_context(self, key: str) -> Tuple[AttributeDict, str]:
def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]:
"""
Returns a reference to a sub-dictionary of the context and the last key,
after resolving a potentially segmented key where required sub-dictionaries are created as needed.
Expand Down Expand Up @@ -155,6 +211,7 @@ def _resolve_nested_context(self, key: str) -> Tuple[AttributeDict, str]:

return ctx, ctx_path[-1]

@Protect.final
def insert_awaitable(self, awaitable: Awaitable) -> None:
"""Insert an awaitable that should be terminated before before continuing to the next step.
Expand All @@ -178,7 +235,8 @@ def insert_awaitable(self, awaitable: Awaitable) -> None:
) # add only if everything went ok, otherwise we end up in an inconsistent state
self._update_process_status()

def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:
@Protect.final
def resolve_awaitable(self, awaitable: Awaitable, value: t.Any) -> None:
"""Resolve an awaitable.
Precondition: must be an awaitable that was previously inserted.
Expand Down Expand Up @@ -210,7 +268,8 @@ def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:
# then we should not try to update it
self._update_process_status()

def to_context(self, **kwargs: Union[Awaitable, ProcessNode]) -> None:
@Protect.final
def to_context(self, **kwargs: Awaitable | ProcessNode) -> None:
"""Add a dictionary of awaitables to the context.
This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will
Expand All @@ -230,11 +289,12 @@ def _update_process_status(self) -> None:
self.node.set_process_status(None)

@override
def run(self) -> Any:
@Protect.final
def run(self) -> t.Any:
self._stepper = self.spec().get_outline().create_stepper(self) # type: ignore[arg-type]
return self._do_step()

def _do_step(self) -> Any:
def _do_step(self) -> t.Any:
"""Execute the next step in the outline and return the result.
If the stepper returns a non-finished status and the return value is of type ToContext, the contents of the
Expand All @@ -245,7 +305,7 @@ def _do_step(self) -> Any:
from .context import ToContext

self._awaitables = []
result: Any = None
result: t.Any = None

try:
assert self._stepper is not None
Expand Down Expand Up @@ -273,7 +333,7 @@ def _do_step(self) -> Any:

return Continue(self._do_step)

def _store_nodes(self, data: Any) -> None:
def _store_nodes(self, data: t.Any) -> None:
"""Recurse through a data structure and store any unstored nodes that are found along the way
:param data: a data structure potentially containing unstored nodes
Expand All @@ -288,6 +348,7 @@ def _store_nodes(self, data: Any) -> None:
self._store_nodes(value)

@override
@Protect.final
def on_exiting(self) -> None:
"""Ensure that any unstored nodes in the context are stored, before the state is exited
Expand All @@ -301,14 +362,16 @@ def on_exiting(self) -> None:
# An uncaught exception here will have bizarre and disastrous consequences
self.logger.exception('exception in _store_nodes called in on_exiting')

def on_wait(self, awaitables: Sequence[Awaitable]):
@Protect.final
def on_wait(self, awaitables: t.Sequence[Awaitable]):
"""Entering the WAITING state."""
super().on_wait(awaitables)
if self._awaitables:
self.action_awaitables()
else:
self.call_soon(self.resume)

@Protect.final
def action_awaitables(self) -> None:
"""Handle the awaitables that are currently registered with the work chain.
Expand All @@ -323,6 +386,7 @@ def action_awaitables(self) -> None:
else:
assert f"invalid awaitable target '{awaitable.target}'"

@Protect.final
def on_process_finished(self, awaitable: Awaitable) -> None:
"""Callback function called by the runner when the process instance identified by pk is completed.
Expand Down
1 change: 1 addition & 0 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ py:class plumpy.utils.AttributesDict
py:class plumpy.process_states.State
py:class plumpy.workchains._If
py:class plumpy.workchains._While
py:class plumpy.processes.ProcessStateMachineMeta
py:class PersistenceError
py:class State
py:class Stepper
Expand Down
32 changes: 22 additions & 10 deletions tests/engine/test_work_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,16 +731,15 @@ class TestWorkChain(WorkChain):
@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.run, cls.check)
spec.outline(cls.emit_report, cls.check)
spec.outputs.dynamic = True

def run(self):
orm.Log.collection.delete_all()
def emit_report(self):
self.report('Testing the report function')

def check(self):
logs = self._backend.logs.find()
assert len(logs) == 1
messages = [log.message for log in orm.Log.collection.get_logs_for(self.node)]
assert any('Testing the report function' in message for message in messages)

run_and_check_success(TestWorkChain)

Expand Down Expand Up @@ -996,12 +995,9 @@ class ExitCodeWorkChain(WorkChain):
@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.run)
spec.outline()
spec.exit_code(status, label, message)

def run(self):
pass

wc = ExitCodeWorkChain()

# The exit code can be gotten by calling it with the status or label, as well as using attribute dereferencing
Expand Down Expand Up @@ -1600,7 +1596,7 @@ def define(cls, spec):
super().define(spec)
spec.input('a', valid_type=Bool, default=lambda: Bool(True))

def run(self):
def step(self):
pass

def test_unique_default_inputs(self):
Expand All @@ -1623,3 +1619,19 @@ def test_unique_default_inputs(self):
# as both `child_one.a` and `child_two.a` should have the same UUID.
node = load_node(uuid=node.base.links.get_incoming().get_node_by_label('child_one__a').uuid)
assert len(uuids) == len(nodes), f'Only {len(uuids)} unique UUIDS for {len(nodes)} input nodes'


def test_illegal_override_run():
"""Test that overriding a protected workchain method raises a ``RuntimeError``."""
with pytest.raises(RuntimeError, match='the method `run` is protected cannot be overridden.'):

class IllegalWorkChain(WorkChain): # pylint: disable=unused-variable
"""Work chain that illegally overrides the ``run`` method."""

@classmethod
def define(cls, spec):
super().define(spec)
spec.outline(cls.run)

def run(self):
pass

0 comments on commit 7929257

Please sign in to comment.