Skip to content

Commit

Permalink
EosWorkChain: fix bug in get_scale_factors (#226)
Browse files Browse the repository at this point in the history
The return type of the `get_scale_factors` method was inconsistent. If
`scale_factors` is defined in the input, it would return a `List` node
of floats, but otherwise it would return a `list` of `Float` nodes. The
method is changed to always return a tuple of normal floats. The caller
then has the responsibility of casting to a `Float` node if necessary.

The input spec of the `EosWorkChain` is also updated to use a serializer
for the `scale_factors`, `scale_count` and `scale_increment` inputs.
This allows a user to pass a simple base type and it will automatically
be converted to the corresponding AiiDA data node type. Note that this
feature does not yet work for `List` but this will be added soon to
`aiida-core`.

Note that the test needs a special condition to transform `list` inputs
into `orm.List` nodes. The reason for this special treatment is that for
`aiida-core<2.0` there is no automatic serializer yet for `list` types.
Once we upgrade the requirement, we can drop this special case.
  • Loading branch information
sphuber authored Oct 21, 2021
1 parent f83cbba commit 2042324
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 45 deletions.
18 changes: 11 additions & 7 deletions aiida_common_workflows/workflows/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,14 @@ def define(cls, spec):
# yapf: disable
super().define(spec)
spec.input('structure', valid_type=orm.StructureData, help='The structure at equilibrium volume.')
spec.input('scale_factors', valid_type=orm.List, required=False, validator=validate_scale_factors,
spec.input('scale_factors', valid_type=orm.List, required=False,
validator=validate_scale_factors, serializer=orm.to_aiida_type,
help='The list of scale factors at which the volume and total energy of the structure should be computed.')
spec.input('scale_count', valid_type=orm.Int, default=lambda: orm.Int(7), validator=validate_scale_count,
spec.input('scale_count', valid_type=orm.Int, default=lambda: orm.Int(7),
validator=validate_scale_count, serializer=orm.to_aiida_type,
help='The number of points to compute for the equation of state.')
spec.input('scale_increment', valid_type=orm.Float, default=lambda: orm.Float(0.02),
validator=validate_scale_increment,
validator=validate_scale_increment, serializer=orm.to_aiida_type,
help='The relative difference between consecutive scaling factors.')
spec.input_namespace('generator_inputs',
help='The inputs that will be passed to the input generator of the specified `sub_process`.')
Expand Down Expand Up @@ -127,11 +129,11 @@ def define(cls, spec):
def get_scale_factors(self):
"""Return the list of scale factors."""
if 'scale_factors' in self.inputs:
return self.inputs.scale_factors
return tuple(self.inputs.scale_factors)

count = self.inputs.scale_count.value
increment = self.inputs.scale_increment.value
return [orm.Float(1 + i * increment - (count - 1) * increment / 2) for i in range(count)]
return tuple(float(1 + i * increment - (count - 1) * increment / 2) for i in range(count))

def get_sub_workchain_builder(self, scale_factor, reference_workchain=None):
"""Return the builder for the relax workchain."""
Expand All @@ -149,7 +151,7 @@ def get_sub_workchain_builder(self, scale_factor, reference_workchain=None):

def run_init(self):
"""Run the first workchain."""
scale_factor = self.get_scale_factors()[0]
scale_factor = orm.Float(self.get_scale_factors()[0])
builder, structure = self.get_sub_workchain_builder(scale_factor)
self.report(f'submitting `{builder.process_class.__name__}` for scale_factor `{scale_factor}`')
self.ctx.reference_workchain = self.submit(builder)
Expand All @@ -166,7 +168,9 @@ def run_eos(self):
"""Run the sub process at each scale factor to compute the structure volume and total energy."""
for scale_factor in self.get_scale_factors()[1:]:
reference_workchain = self.ctx.reference_workchain
builder, structure = self.get_sub_workchain_builder(scale_factor, reference_workchain=reference_workchain)
builder, structure = self.get_sub_workchain_builder(
orm.Float(scale_factor), reference_workchain=reference_workchain
)
self.report(f'submitting `{builder.process_class.__name__}` for scale_factor `{scale_factor}`')
self.ctx.structures.append(structure)
self.to_context(children=append_(self.submit(builder)))
Expand Down
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,30 @@ def _generate_code(entry_point):
return _generate_code


@pytest.fixture
def generate_workchain():
    """Return a factory to construct a ``WorkChain`` instance."""

    def _generate_workchain(entry_point, inputs):
        """Construct a ``WorkChain`` instance for the given entry point.

        :param entry_point: entry point name of the work chain subclass.
        :param inputs: inputs to be passed to process construction.
        :return: a ``WorkChain`` instance.
        """
        from aiida.engine.utils import instantiate_process
        from aiida.manage.manager import get_manager
        from aiida.plugins import WorkflowFactory

        workchain_cls = WorkflowFactory(entry_point)
        runner = get_manager().get_runner()

        return instantiate_process(runner, workchain_cls, **inputs)

    return _generate_workchain


@pytest.fixture
def generate_eos_node(generate_structure):
"""Generate an instance of ``EquationOfStateWorkChain``."""
Expand Down
100 changes: 62 additions & 38 deletions tests/workflows/eos/test_workchain_eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,35 @@ def common_relax_workchain(request) -> CommonRelaxWorkChain:
return WorkflowFactory(request.param)


@pytest.fixture
def generate_eos_inputs(generate_structure, generate_code):
    """Return a factory for a dictionary of default inputs for the ``EquationOfStateWorkChain``.

    .. note:: A ``@pytest.mark.usefixtures('aiida_profile')`` decorator was removed from this fixture: pytest
        ignores marks applied to fixtures, so it had no effect. Tests or fixtures that require a loaded profile
        should request the ``aiida_profile`` fixture explicitly instead.
    """

    def _generate_eos_inputs():
        """Build a fresh inputs dictionary; a new structure and a new stored code node are created on each call."""
        return {
            'structure': generate_structure(symbols=('Si',)),
            'sub_process_class': 'common_workflows.relax.quantum_espresso',
            'generator_inputs': {
                'protocol': 'fast',
                'engines': {
                    'relax': {
                        'code': generate_code('quantumespresso.pw').store(),
                        'options': {
                            'resources': {
                                'num_machines': 1
                            }
                        }
                    }
                },
                'electronic_type': 'metal',
                'relax_type': 'positions'
            }
        }

    return _generate_eos_inputs


def test_validate_sub_process_class(ctx):
"""Test the `validate_sub_process_class` validator."""
for value in [None, WorkChain]:
Expand All @@ -41,25 +70,9 @@ def test_validate_sub_process_class_plugins(ctx, common_relax_workchain):


@pytest.mark.usefixtures('sssp')
def test_validate_inputs_scale(ctx, generate_code, generate_structure):
def test_validate_inputs_scale(ctx, generate_eos_inputs):
"""Test the ``validate_inputs`` validator for invalid scale inputs."""
base_values = {
'structure': generate_structure(symbols=('Si',)),
'sub_process_class': 'common_workflows.relax.quantum_espresso',
'generator_inputs': {
'engines': {
'relax': {
'code': generate_code('quantumespresso.pw'),
'options': {
'resources': {
'num_machines': 1
}
}
}
},
'electronic_type': 'metal'
}
}
base_values = generate_eos_inputs()

value = copy.deepcopy(base_values)
assert eos.validate_inputs(
Expand Down Expand Up @@ -88,27 +101,10 @@ def test_validate_inputs_scale(ctx, generate_code, generate_structure):


@pytest.mark.usefixtures('sssp')
def test_validate_inputs_generator_inputs(ctx, generate_code, generate_structure):
def test_validate_inputs_generator_inputs(ctx, generate_eos_inputs):
"""Test the ``validate_inputs`` validator for invalid generator inputs."""
value = {
'scale_factors': [],
'structure': generate_structure(symbols=('Si',)),
'sub_process_class': 'common_workflows.relax.quantum_espresso',
'generator_inputs': {
'engines': {
'relax': {
'code': generate_code('quantumespresso.pw'),
'options': {
'resources': {
'num_machines': 1
}
}
}
},
'electronic_type': 'metal'
}
}

value = generate_eos_inputs()
value['scale_factors'] = []
assert eos.validate_inputs(value, ctx) is None

value['generator_inputs']['electronic_type'] = 'invalid_value'
Expand Down Expand Up @@ -145,3 +141,31 @@ def test_validate_relax_type(ctx):
assert eos.validate_relax_type(
RelaxType.CELL, ctx
) == '`generator_inputs.relax_type`. Equation of state and relaxation with variable volume not compatible.'


@pytest.mark.parametrize(
    'scaling_inputs, expected', (
        ({
            'scale_factors': [0.98, 1.0, 1.02]
        }, (0.98, 1.0, 1.02)),
        ({
            'scale_count': 3,
            'scale_increment': 0.02
        }, (0.98, 1.0, 1.02)),
    )
)
@pytest.mark.usefixtures('sssp')
def test_get_scale_factors(generate_workchain, generate_eos_inputs, scaling_inputs, expected):
    """Test the ``EquationOfStateWorkChain.get_scale_factors`` method."""
    inputs = generate_eos_inputs()

    # For `aiida-core<2.0` a plain `list` is not automatically serialized to a `List` node, so convert it by hand
    # here. The conversion cannot be done inside the `parametrize` call itself, because at collection time no AiiDA
    # profile has been loaded yet and creating a node would raise. Drop this once `aiida-core>=2.0` is required.
    scale_factors = scaling_inputs.get('scale_factors')
    if isinstance(scale_factors, list):
        scaling_inputs['scale_factors'] = orm.List(list=scale_factors)

    inputs.update(scaling_inputs)
    process = generate_workchain('common_workflows.eos', inputs)
    assert process.get_scale_factors() == expected

0 comments on commit 2042324

Please sign in to comment.