From 8032a8afd9d1a910b37113d1ea7d87f86b426f9d Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 2 Nov 2022 23:53:07 +0100 Subject: [PATCH] Tests: Add tests for calcfunction as `WorkChain` class member methods The test defines a workchain that defines a calcfunction as a class member in two ways: * A proper staticmethod * A class attribute Both versions work, but the former is more correct and is the only one that can be called from an instance of the workchain. The other has to be invoked by retrieving the calcfunction as an attribute and then calling it. The former is the only one that is offcially documented. The tests verify that both methods of declaring the calcfunction can be called within the `WorkChain` and that it also works when submitting the workchain to the daemon. A test is also added that verifies that caching of the calcfunctions functions properly. The changes also broke one existing test. The test took a calcfunction, originally defined on the module level, and modifying it within the test function scope. The goal of the test was twofold: * Check that changing the source code would be recognized by the caching mechanism and so an invocation of the changed function would not be cached from an invocation of the original * Check that it is possible to cache from a function defined inside the scope of a function. The changes to the dynamically built `FunctionProcess`, notably changing the type from `func.__name__` to `func.__qualname__` stopped the second point from working. In the original code, the type name would be simply `tests.engine.test_calcfunctions.add_function`, both for the module level function as well as the inlined function. However, with the change this becomes: `tests.engine.test_calcfunctions.TestCalcFunction.test_calcfunction_caching_change_code..add_calcfunction` this can no longer be loaded by the `ProcessNode.process_class` property and so `is_valid_cache` returns `False`, whereas in the original code it was a valid cache as the process class could be loaded. Arguably, the new code is more correct, but it is breaking. Before inlined functions were valid cache sources, but that is no longer the case. In exchange, class member functions are now valid cache sources where they weren't before. Arguably, it is preferable to support class member functions over inline functions. The broken test is fixed by moving the inlined `add_calcfunction` to a separate module such that it becomes a valid cache source again. --- aiida/orm/nodes/process/process.py | 6 +- docs/source/topics/processes/functions.rst | 33 ++++++++++ tests/engine/calcfunctions.py | 10 +++ tests/engine/test_calcfunctions.py | 22 ++++--- tests/engine/test_work_chain.py | 73 +++++++++++++++++++++- 5 files changed, 131 insertions(+), 13 deletions(-) create mode 100644 tests/engine/calcfunctions.py diff --git a/aiida/orm/nodes/process/process.py b/aiida/orm/nodes/process/process.py index 40f584ac92..c677d2c74f 100644 --- a/aiida/orm/nodes/process/process.py +++ b/aiida/orm/nodes/process/process.py @@ -220,7 +220,7 @@ def process_class(self) -> Type['Process']: raise ValueError( f'could not load process class for entry point `{self.process_type}` for Node<{self.pk}>: {exception}' ) from exception - except ValueError: + except ValueError as exception: import importlib def str_rsplit_iter(string, sep='.'): @@ -239,8 +239,8 @@ def str_rsplit_iter(string, sep='.'): pass else: raise ValueError( - f'could not load process class from `{self.process_type}` for Node<{self.pk}>: {exception}' - ) + f'could not load process class from `{self.process_type}` for Node<{self.pk}>' + ) from exception return process_class diff --git a/docs/source/topics/processes/functions.rst b/docs/source/topics/processes/functions.rst index 25dfeb80d2..2741d2a6fe 100644 --- a/docs/source/topics/processes/functions.rst +++ b/docs/source/topics/processes/functions.rst @@ -212,6 +212,39 @@ The question you should ask yourself is whether a potential problem merits throw Or maybe, as in the example above, the problem is easily foreseeable and classifiable with a well defined exit status, in which case it might make more sense to return the exit code. At the end one should think which solution makes it easier for a workflow calling the function to respond based on the result and what makes it easier to query for these specific failure modes. +As class member methods +======================= + +.. versionadded:: 2.3 + +Process functions can also be declared as class member methods, for example as part of a :class:`~aiida.engine.processes.workchains.workchain.WorkChain`: + +.. code-block:: python + + class CalcFunctionWorkChain(WorkChain): + + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('x') + spec.input('y') + spec.output('sum') + spec.outline( + cls.run_compute_sum, + ) + + @staticmethod + @calcfunction + def compute_sum(x, y): + return x + y + + def run_compute_sum(self): + self.out('sum', self.compute_sum(self.inputs.x, self.inputs.y)) + +In this example, the work chain declares a class method called ``compute_sum`` which is decorated with the ``calcfunction`` decorator to turn it into a calculation function. +It is important that the method is also decorated with the ``staticmethod`` (see the `Python documentation `_) such that the work chain instance is not passed when the method is invoked. +The calcfunction can be called from a work chain step like any other class method, as is shown in the last line. + Provenance ========== diff --git a/tests/engine/calcfunctions.py b/tests/engine/calcfunctions.py new file mode 100644 index 0000000000..771b1dadec --- /dev/null +++ b/tests/engine/calcfunctions.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +"""Definition of a calculation function used in ``test_calcfunctions.py``.""" +from aiida.engine import calcfunction +from aiida.orm import Int + + +@calcfunction +def add_calcfunction(data): + """Calcfunction mirroring a ``test_calcfunctions`` calcfunction but has a slightly different implementation.""" + return Int(data.value + 2) diff --git a/tests/engine/test_calcfunctions.py b/tests/engine/test_calcfunctions.py index 7a0d589a3b..4c679991ca 100644 --- a/tests/engine/test_calcfunctions.py +++ b/tests/engine/test_calcfunctions.py @@ -100,23 +100,27 @@ def test_calcfunction_caching(self): assert cached.base.links.get_incoming().one().node.uuid == input_node.uuid def test_calcfunction_caching_change_code(self): - """Verify that changing the source codde of a calcfunction invalidates any existing cached nodes.""" - result_original = self.test_calcfunction(self.default_int) + """Verify that changing the source code of a calcfunction invalidates any existing cached nodes. - # Intentionally using the same name, to check that caching anyway - # distinguishes between the calcfunctions. - @calcfunction - def add_calcfunction(data): # pylint: disable=redefined-outer-name - """This calcfunction has a different source code from the one created at the module level.""" - return Int(data.value + 2) + The ``add_calcfunction`` of the ``calcfunctions`` module uses the exact same name as the one defined in this + test module, however, it has a slightly different implementation. Note that we have to define the duplicate in + a different module, because we cannot define it in the same module (as the name clashes, on purpose) and we + cannot inline the calcfunction in this test, since inlined process functions are not valid cache sources. + """ + from .calcfunctions import add_calcfunction # pylint: disable=redefined-outer-name + + result_original = self.test_calcfunction(self.default_int) with enable_caching(identifier='*.add_calcfunction'): result_cached, cached = add_calcfunction.run_get_node(self.default_int) assert result_original != result_cached assert not cached.base.caching.is_created_from_cache + assert cached.is_valid_cache + # Test that the locally-created calcfunction can be cached in principle result2_cached, cached2 = add_calcfunction.run_get_node(self.default_int) - assert result_original != result2_cached + assert result2_cached != result_original + assert result2_cached == result_cached assert cached2.base.caching.is_created_from_cache def test_calcfunction_do_not_store_provenance(self): diff --git a/tests/engine/test_work_chain.py b/tests/engine/test_work_chain.py index 2262e4ed3c..779beab5e0 100644 --- a/tests/engine/test_work_chain.py +++ b/tests/engine/test_work_chain.py @@ -21,7 +21,7 @@ from aiida.common.utils import Capturing from aiida.engine import ExitCode, Process, ToContext, WorkChain, append_, calcfunction, if_, launch, return_, while_ from aiida.engine.persistence import ObjectLoader -from aiida.manage import get_manager +from aiida.manage import enable_caching, get_manager from aiida.orm import Bool, Float, Int, Str, load_node @@ -146,6 +146,36 @@ def _set_finished(self, function_name): self.finished_steps[function_name] = True +class CalcFunctionWorkChain(WorkChain): + + @classmethod + def define(cls, spec): + super().define(spec) + spec.input('a') + spec.input('b') + spec.output('out_member') + spec.output('out_static') + spec.outline( + cls.run_add_member, + cls.run_add_static, + ) + + @calcfunction + def add_member(a, b): # pylint: disable=no-self-argument + return a + b + + @staticmethod + @calcfunction + def add_static(a, b): + return a + b + + def run_add_member(self): + self.out('out_member', CalcFunctionWorkChain.add_member(self.inputs.a, self.inputs.b)) + + def run_add_static(self): + self.out('out_static', self.add_static(self.inputs.a, self.inputs.b)) + + class PotentialFailureWorkChain(WorkChain): """Work chain that can finish with a non-zero exit code.""" @@ -1031,6 +1061,47 @@ def _run_with_checkpoints(wf_class, inputs=None): proc = run_and_check_success(wf_class, **inputs) return proc.finished_steps + def test_member_calcfunction(self): + """Test defining a calcfunction as a ``WorkChain`` member method.""" + results, node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2)) + assert node.is_finished_ok + assert results['out_member'] == 3 + assert results['out_static'] == 3 + + @pytest.mark.usefixtures('aiida_profile_clean') + def test_member_calcfunction_caching(self): + """Test defining a calcfunction as a ``WorkChain`` member method with caching enabled.""" + results, node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2)) + assert node.is_finished_ok + assert results['out_member'] == 3 + assert results['out_static'] == 3 + + with enable_caching(): + results, cached_node = launch.run.get_node(CalcFunctionWorkChain, a=Int(1), b=Int(2)) + assert cached_node.is_finished_ok + assert results['out_member'] == 3 + assert results['out_static'] == 3 + + # Check that the calcfunctions called by the workchain have been cached + for called in cached_node.called: + assert called.base.caching.is_created_from_cache + assert called.base.caching.get_cache_source() in [n.uuid for n in node.called] + + def test_member_calcfunction_daemon(self, entry_points, daemon_client, submit_and_await): + """Test defining a calcfunction as a ``WorkChain`` member method submitted to the daemon.""" + entry_points.add(CalcFunctionWorkChain, 'aiida.workflows:testing.calcfunction.workchain') + + daemon_client.start_daemon() + + builder = CalcFunctionWorkChain.get_builder() + builder.a = Int(1) + builder.b = Int(2) + + node = submit_and_await(builder) + assert node.is_finished_ok + assert node.outputs.out_member == 3 + assert node.outputs.out_static == 3 + @pytest.mark.requires_rmq class TestWorkChainAbort: