From 2042324cbac9f45765c993cb7851fb3fe460508a Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 21 Oct 2021 19:48:55 +0200 Subject: [PATCH] `EosWorkChain`: fix bug in `get_scale_factors` (#226) The return type of the `get_scale_factors` method was inconsistent. If `scale_factors` is defined in the input, it would return a `List` node of floats, but otherwise it would return a `list` of `Float` nodes. The method is changed to always return a tuple of normal floats. The caller then has the responsability of casting to a `Float` node if necessary. The input spec of the `EosWorkChain` is also updated to use a serializer for the `scale_factors`, `scale_count` and `scale_increment` inputs. This allows a user to pass a simple base type and it will automatically be converted to the corresponding AiiDA data node type. Note that this feature does not yet work for `List` but this will be added soon to `aiida-core`. Note that the test needs a special condition to transform `list` inputs into `orm.List` nodes. The reason for this special treatment is that for `aiida-core<2.0` there is no automatic serializer yet for `list` types. Once we upgrade the requirement, we can drop this special case. --- aiida_common_workflows/workflows/eos.py | 18 ++-- tests/conftest.py | 24 ++++++ tests/workflows/eos/test_workchain_eos.py | 100 ++++++++++++++-------- 3 files changed, 97 insertions(+), 45 deletions(-) diff --git a/aiida_common_workflows/workflows/eos.py b/aiida_common_workflows/workflows/eos.py index 86a702bd..ae2c511b 100644 --- a/aiida_common_workflows/workflows/eos.py +++ b/aiida_common_workflows/workflows/eos.py @@ -81,12 +81,14 @@ def define(cls, spec): # yapf: disable super().define(spec) spec.input('structure', valid_type=orm.StructureData, help='The structure at equilibrium volume.') - spec.input('scale_factors', valid_type=orm.List, required=False, validator=validate_scale_factors, + spec.input('scale_factors', valid_type=orm.List, required=False, + validator=validate_scale_factors, serializer=orm.to_aiida_type, help='The list of scale factors at which the volume and total energy of the structure should be computed.') - spec.input('scale_count', valid_type=orm.Int, default=lambda: orm.Int(7), validator=validate_scale_count, + spec.input('scale_count', valid_type=orm.Int, default=lambda: orm.Int(7), + validator=validate_scale_count, serializer=orm.to_aiida_type, help='The number of points to compute for the equation of state.') spec.input('scale_increment', valid_type=orm.Float, default=lambda: orm.Float(0.02), - validator=validate_scale_increment, + validator=validate_scale_increment, serializer=orm.to_aiida_type, help='The relative difference between consecutive scaling factors.') spec.input_namespace('generator_inputs', help='The inputs that will be passed to the input generator of the specified `sub_process`.') @@ -127,11 +129,11 @@ def define(cls, spec): def get_scale_factors(self): """Return the list of scale factors.""" if 'scale_factors' in self.inputs: - return self.inputs.scale_factors + return tuple(self.inputs.scale_factors) count = self.inputs.scale_count.value increment = self.inputs.scale_increment.value - return [orm.Float(1 + i * increment - (count - 1) * increment / 2) for i in range(count)] + return tuple(float(1 + i * increment - (count - 1) * increment / 2) for i in range(count)) def get_sub_workchain_builder(self, scale_factor, reference_workchain=None): """Return the builder for the relax workchain.""" @@ -149,7 +151,7 @@ def get_sub_workchain_builder(self, scale_factor, reference_workchain=None): def run_init(self): """Run the first workchain.""" - scale_factor = self.get_scale_factors()[0] + scale_factor = orm.Float(self.get_scale_factors()[0]) builder, structure = self.get_sub_workchain_builder(scale_factor) self.report(f'submitting `{builder.process_class.__name__}` for scale_factor `{scale_factor}`') self.ctx.reference_workchain = self.submit(builder) @@ -166,7 +168,9 @@ def run_eos(self): """Run the sub process at each scale factor to compute the structure volume and total energy.""" for scale_factor in self.get_scale_factors()[1:]: reference_workchain = self.ctx.reference_workchain - builder, structure = self.get_sub_workchain_builder(scale_factor, reference_workchain=reference_workchain) + builder, structure = self.get_sub_workchain_builder( + orm.Float(scale_factor), reference_workchain=reference_workchain + ) self.report(f'submitting `{builder.process_class.__name__}` for scale_factor `{scale_factor}`') self.ctx.structures.append(structure) self.to_context(children=append_(self.submit(builder))) diff --git a/tests/conftest.py b/tests/conftest.py index 7feef3c5..49e7b569 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -106,6 +106,30 @@ def _generate_code(entry_point): return _generate_code +@pytest.fixture +def generate_workchain(): + """Generate an instance of a ``WorkChain``.""" + + def _generate_workchain(entry_point, inputs): + """Generate an instance of a ``WorkChain`` with the given entry point and inputs. + + :param entry_point: entry point name of the work chain subclass. + :param inputs: inputs to be passed to process construction. + :return: a ``WorkChain`` instance. + """ + from aiida.engine.utils import instantiate_process + from aiida.manage.manager import get_manager + from aiida.plugins import WorkflowFactory + + process_class = WorkflowFactory(entry_point) + runner = get_manager().get_runner() + process = instantiate_process(runner, process_class, **inputs) + + return process + + return _generate_workchain + + @pytest.fixture def generate_eos_node(generate_structure): """Generate an instance of ``EquationOfStateWorkChain``.""" diff --git a/tests/workflows/eos/test_workchain_eos.py b/tests/workflows/eos/test_workchain_eos.py index f90475fc..511be030 100644 --- a/tests/workflows/eos/test_workchain_eos.py +++ b/tests/workflows/eos/test_workchain_eos.py @@ -27,6 +27,35 @@ def common_relax_workchain(request) -> CommonRelaxWorkChain: return WorkflowFactory(request.param) +@pytest.fixture +@pytest.mark.usefixtures('aiida_profile') +def generate_eos_inputs(generate_structure, generate_code): + """Return a dictionary of defaults inputs for the ``EquationOfStateWorkChain``.""" + + def _generate_eos_inputs(): + return { + 'structure': generate_structure(symbols=('Si',)), + 'sub_process_class': 'common_workflows.relax.quantum_espresso', + 'generator_inputs': { + 'protocol': 'fast', + 'engines': { + 'relax': { + 'code': generate_code('quantumespresso.pw').store(), + 'options': { + 'resources': { + 'num_machines': 1 + } + } + } + }, + 'electronic_type': 'metal', + 'relax_type': 'positions' + } + } + + return _generate_eos_inputs + + def test_validate_sub_process_class(ctx): """Test the `validate_sub_process_class` validator.""" for value in [None, WorkChain]: @@ -41,25 +70,9 @@ def test_validate_sub_process_class_plugins(ctx, common_relax_workchain): @pytest.mark.usefixtures('sssp') -def test_validate_inputs_scale(ctx, generate_code, generate_structure): +def test_validate_inputs_scale(ctx, generate_eos_inputs): """Test the ``validate_inputs`` validator for invalid scale inputs.""" - base_values = { - 'structure': generate_structure(symbols=('Si',)), - 'sub_process_class': 'common_workflows.relax.quantum_espresso', - 'generator_inputs': { - 'engines': { - 'relax': { - 'code': generate_code('quantumespresso.pw'), - 'options': { - 'resources': { - 'num_machines': 1 - } - } - } - }, - 'electronic_type': 'metal' - } - } + base_values = generate_eos_inputs() value = copy.deepcopy(base_values) assert eos.validate_inputs( @@ -88,27 +101,10 @@ def test_validate_inputs_scale(ctx, generate_code, generate_structure): @pytest.mark.usefixtures('sssp') -def test_validate_inputs_generator_inputs(ctx, generate_code, generate_structure): +def test_validate_inputs_generator_inputs(ctx, generate_eos_inputs): """Test the ``validate_inputs`` validator for invalid generator inputs.""" - value = { - 'scale_factors': [], - 'structure': generate_structure(symbols=('Si',)), - 'sub_process_class': 'common_workflows.relax.quantum_espresso', - 'generator_inputs': { - 'engines': { - 'relax': { - 'code': generate_code('quantumespresso.pw'), - 'options': { - 'resources': { - 'num_machines': 1 - } - } - } - }, - 'electronic_type': 'metal' - } - } - + value = generate_eos_inputs() + value['scale_factors'] = [] assert eos.validate_inputs(value, ctx) is None value['generator_inputs']['electronic_type'] = 'invalid_value' @@ -145,3 +141,31 @@ def test_validate_relax_type(ctx): assert eos.validate_relax_type( RelaxType.CELL, ctx ) == '`generator_inputs.relax_type`. Equation of state and relaxation with variable volume not compatible.' + + +@pytest.mark.parametrize( + 'scaling_inputs, expected', ( + ({ + 'scale_factors': [0.98, 1.0, 1.02] + }, (0.98, 1.0, 1.02)), + ({ + 'scale_count': 3, + 'scale_increment': 0.02 + }, (0.98, 1.0, 1.02)), + ) +) +@pytest.mark.usefixtures('sssp') +def test_get_scale_factors(generate_workchain, generate_eos_inputs, scaling_inputs, expected): + """Test the ``EquationOfStateWorkChain.get_scale_factors`` method.""" + inputs = generate_eos_inputs() + + # This conditional and conversion is necessary because for `aiida-core<2.0` the `list` type is not automatically + # serialized to a `List` node. Once we require `aiida-core>=2.0`, this can be removed. The reason we couldn't + # already simply turn the ``scaling_inputs`` into a ``orm.List`` is that during the parametrization done by pytest + # no AiiDA profile will have been loaded yet and so creating a node will raise an exception. + if 'scale_factors' in scaling_inputs and isinstance(scaling_inputs['scale_factors'], list): + scaling_inputs['scale_factors'] = orm.List(list=scaling_inputs['scale_factors']) + + inputs.update(scaling_inputs) + process = generate_workchain('common_workflows.eos', inputs) + assert process.get_scale_factors() == expected