diff --git a/aiida/orm/nodes/process/process.py b/aiida/orm/nodes/process/process.py index 40f584ac92..c677d2c74f 100644 --- a/aiida/orm/nodes/process/process.py +++ b/aiida/orm/nodes/process/process.py @@ -220,7 +220,7 @@ def process_class(self) -> Type['Process']: raise ValueError( f'could not load process class for entry point `{self.process_type}` for Node<{self.pk}>: {exception}' ) from exception - except ValueError: + except ValueError as exception: import importlib def str_rsplit_iter(string, sep='.'): @@ -239,8 +239,8 @@ def str_rsplit_iter(string, sep='.'): pass else: raise ValueError( - f'could not load process class from `{self.process_type}` for Node<{self.pk}>: {exception}' - ) + f'could not load process class from `{self.process_type}` for Node<{self.pk}>' + ) from exception return process_class diff --git a/docs/source/topics/processes/functions.rst b/docs/source/topics/processes/functions.rst index 25dfeb80d2..2741d2a6fe 100644 --- a/docs/source/topics/processes/functions.rst +++ b/docs/source/topics/processes/functions.rst @@ -212,6 +212,39 @@ The question you should ask yourself is whether a potential problem merits throw Or maybe, as in the example above, the problem is easily foreseeable and classifiable with a well defined exit status, in which case it might make more sense to return the exit code. At the end one should think which solution makes it easier for a workflow calling the function to respond based on the result and what makes it easier to query for these specific failure modes. +As class member methods +======================= + +.. versionadded:: 2.3 + +Process functions can also be declared as class member methods, for example as part of a :class:`~aiida.engine.processes.workchains.workchain.WorkChain`: + +.. code-block:: python + + class CalcFunctionWorkChain(WorkChain): + + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('x') + spec.input('y') + spec.output('sum') + spec.outline( + cls.run_compute_sum, + ) + + @staticmethod + @calcfunction + def compute_sum(x, y): + return x + y + + def run_compute_sum(self): + self.out('sum', self.compute_sum(self.inputs.x, self.inputs.y)) + +In this example, the work chain declares a class method called ``compute_sum`` which is decorated with the ``calcfunction`` decorator to turn it into a calculation function. +It is important that the method is also decorated with the ``staticmethod`` (see the `Python documentation `_) such that the work chain instance is not passed when the method is invoked. +The calcfunction can be called from a work chain step like any other class method, as is shown in the last line. + Provenance ========== diff --git a/tests/engine/calcfunctions.py b/tests/engine/calcfunctions.py new file mode 100644 index 0000000000..771b1dadec --- /dev/null +++ b/tests/engine/calcfunctions.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +"""Definition of a calculation function used in ``test_calcfunctions.py``.""" +from aiida.engine import calcfunction +from aiida.orm import Int + + +@calcfunction +def add_calcfunction(data): + """Calcfunction mirroring a ``test_calcfunctions`` calcfunction but has a slightly different implementation.""" + return Int(data.value + 2) diff --git a/tests/engine/test_calcfunctions.py b/tests/engine/test_calcfunctions.py index 7a0d589a3b..4c679991ca 100644 --- a/tests/engine/test_calcfunctions.py +++ b/tests/engine/test_calcfunctions.py @@ -100,23 +100,27 @@ def test_calcfunction_caching(self): assert cached.base.links.get_incoming().one().node.uuid == input_node.uuid def test_calcfunction_caching_change_code(self): - """Verify that changing the source codde of a calcfunction invalidates any existing cached nodes.""" - result_original = self.test_calcfunction(self.default_int) + """Verify that changing the source code of a calcfunction invalidates any existing cached nodes. - # Intentionally using the same name, to check that caching anyway - # distinguishes between the calcfunctions. - @calcfunction - def add_calcfunction(data): # pylint: disable=redefined-outer-name - """This calcfunction has a different source code from the one created at the module level.""" - return Int(data.value + 2) + The ``add_calcfunction`` of the ``calcfunctions`` module uses the exact same name as the one defined in this + test module, however, it has a slightly different implementation. Note that we have to define the duplicate in + a different module, because we cannot define it in the same module (as the name clashes, on purpose) and we + cannot inline the calcfunction in this test, since inlined process functions are not valid cache sources. + """ + from .calcfunctions import add_calcfunction # pylint: disable=redefined-outer-name + + result_original = self.test_calcfunction(self.default_int) with enable_caching(identifier='*.add_calcfunction'): result_cached, cached = add_calcfunction.run_get_node(self.default_int) assert result_original != result_cached assert not cached.base.caching.is_created_from_cache + assert cached.is_valid_cache + # Test that the locally-created calcfunction can be cached in principle result2_cached, cached2 = add_calcfunction.run_get_node(self.default_int) - assert result_original != result2_cached + assert result2_cached != result_original + assert result2_cached == result_cached assert cached2.base.caching.is_created_from_cache def test_calcfunction_do_not_store_provenance(self): diff --git a/tests/engine/test_work_chain.py b/tests/engine/test_work_chain.py index 2262e4ed3c..779beab5e0 100644 --- a/tests/engine/test_work_chain.py +++ b/tests/engine/test_work_chain.py @@ -21,7 +21,7 @@ from aiida.common.utils import Capturing from aiida.engine import ExitCode, Process, ToContext, WorkChain, append_, calcfunction, if_, launch, return_, while_ from aiida.engine.persistence import ObjectLoader -from aiida.manage import get_manager +from aiida.manage import enable_caching, get_manager from aiida.orm import Bool, Float, Int, Str, load_node @@ -146,6 +146,36 @@ def _set_finished(self, function_name): self.finished_steps[function_name] = True +class CalcFunctionWorkChain(WorkChain): + + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('a') + spec.input('b') + spec.output('out_member') + spec.output('out_static') + spec.outline( + cls.run_add_member, + cls.run_add_static, + ) + + @calcfunction + def add_member(a, b): # pylint: disable=no-self-argument + return a + b + + @staticmethod + @calcfunction + def add_static(a, b): + return a + b + + def run_add_member(self): + self.out('out_member', CalcFunctionWorkChain.add_member(self.inputs.a, self.inputs.b)) + + def run_add_static(self): + self.out('out_static', self.add_static(self.inputs.a, self.inputs.b)) + + class PotentialFailureWorkChain(WorkChain): """Work chain that can finish with a non-zero exit code.""" @@ -1031,6 +1061,47 @@ def _run_with_checkpoints(wf_class, inputs=None): proc = run_and_check_success(wf_class, **inputs) return proc.finished_steps + def test_member_calcfunction(self): + """Test defining a calcfunction as a ``WorkChain`` member method.""" + results, node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2)) + assert node.is_finished_ok + assert results['out_member'] == 3 + assert results['out_static'] == 3 + + @pytest.mark.usefixtures('aiida_profile_clean') + def test_member_calcfunction_caching(self): + """Test defining a calcfunction as a ``WorkChain`` member method with caching enabled.""" + results, node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2)) + assert node.is_finished_ok + assert results['out_member'] == 3 + assert results['out_static'] == 3 + + with enable_caching(): + results, cached_node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2)) + assert cached_node.is_finished_ok + assert results['out_member'] == 3 + assert results['out_static'] == 3 + + # Check that the calcfunctions called by the workchain have been cached + for called in cached_node.called: + assert called.base.caching.is_created_from_cache + assert called.base.caching.get_cache_source() in [n.uuid for n in node.called] + + def test_member_calcfunction_daemon(self, entry_points, daemon_client, submit_and_await): + """Test defining a calcfunction as a ``WorkChain`` member method submitted to the daemon.""" + entry_points.add(CalcFunctionWorkChain, 'aiida.workflows:testing.calcfunction.workchain') + + daemon_client.start_daemon() + + builder = CalcFunctionWorkChain.get_builder() + builder.a = Int(1) + builder.b = Int(2) + + node = submit_and_await(builder) + assert node.is_finished_ok + assert node.outputs.out_member == 3 + assert node.outputs.out_static == 3 + @pytest.mark.requires_rmq class TestWorkChainAbort: