Skip to content

Commit

Permalink
Workflows: add the skip_relax_iterations logic
Browse files Browse the repository at this point in the history
We implement the logic to skip a certain number of relax iterations.
For these iterations the check on convergence is also skipped.
Some fixes are also added.
  • Loading branch information
bastonero committed Jun 6, 2023
1 parent 043fece commit 9db7ced
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
def structure_relabel_kinds(
hubbard_structure: HubbardStructureData,
hubbard: Dict,
magnetization: Dict | None = None,
magnetization: dict | None = None,
) -> Dict:
"""Create a clone of the given structure but with new kinds, based on the new hubbard sites.
Expand Down
192 changes: 81 additions & 111 deletions src/aiida_quantumespresso_hp/workflows/hubbard.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def get_separated_parameters(
return onsites, intersites


def validate_positive(value, _):
"""Validate that the value is positive."""
if value.value < 0:
return 'the value must be positive.'


def validate_inputs(inputs, _):
"""Validate the entire inputs."""
parameters = AttributeDict(inputs).scf.pw.parameters.get_dict()
Expand Down Expand Up @@ -100,74 +106,36 @@ class SelfConsistentHubbardWorkChain(WorkChain, ProtocolMixin):
@classmethod
def define(cls, spec):
"""Define the specifications of the process."""
# yapf: disable
super().define(spec)

spec.input('hubbard_structure', valid_type=HubbardStructureData)
spec.input(
'tolerance_onsite',
valid_type=orm.Float,
default=lambda: orm.Float(0.1),
help=(
'Tolerance value for self-consistent calculation of Hubbard U. '
'In case of DFT+U+V calculation, it refers to the diagonal elements (i.e. on-site).'
)
)
spec.input(
'tolerance_intersite',
valid_type=orm.Float,
default=lambda: orm.Float(0.01),
help=(
'Tolerance value for self-consistent DFT+U+V calculation. '
'It refers to the only off-diagonal elements V.'
)
)
spec.input(
'skip_first_relax',
valid_type=orm.Bool,
default=lambda: orm.Bool(False),
help='If True, skip the first relaxation'
)
spec.input(
'relax_frequency',
valid_type=orm.Int,
required=False,
help='Integer value referring to the number of iterations to wait before performing the `relax` step.'
)
spec.expose_inputs(
PwRelaxWorkChain,
namespace='relax',
exclude=(
'clean_workdir',
'structure',
),
namespace_options={
'required': False,
'populate_defaults': False,
'help': 'Inputs for the `PwRelaxWorkChain` that, when defined, will iteratively relax the structure.'
}
)
spec.expose_inputs(PwBaseWorkChain, namespace='scf', exclude=(
'clean_workdir',
'pw.structure',
))
spec.expose_inputs(
HpWorkChain,
namespace='hubbard',
exclude=(
'clean_workdir',
'hp.parent_scf',
'hp.parent_hp',
'hp.hubbard_structure',
)
)
spec.input('max_iterations', valid_type=orm.Int, default=lambda: orm.Int(10))
spec.input('meta_convergence', valid_type=orm.Bool, default=lambda: orm.Bool(False))
spec.input(
'clean_workdir',
valid_type=orm.Bool,
default=lambda: orm.Bool(True),
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.'
)
spec.input('hubbard_structure', valid_type=HubbardStructureData,
help=('The HubbardStructureData containing the initialized parameters for triggering '
'the Hubbard atoms which the `hp.x` code will perturbe.'))
spec.input('tolerance_onsite', valid_type=orm.Float, default=lambda: orm.Float(0.1),
help=('Tolerance value for self-consistent calculation of Hubbard U. '
'In case of DFT+U+V calculation, it refers to the diagonal elements (i.e. on-site).'))
spec.input('tolerance_intersite', valid_type=orm.Float, default=lambda: orm.Float(0.01),
help=('Tolerance value for self-consistent DFT+U+V calculation. '
'It refers to the only off-diagonal elements V.'))
spec.input('skip_relax_iterations', valid_type=orm.Int, required=False, validator=validate_positive,
help=('The number of iterations for skipping the `relax` '
'step without performing check on parameters convergence.'))
spec.input('relax_frequency', valid_type=orm.Int, required=False, validator=validate_positive,
help='Integer value referring to the number of iterations to wait before performing the `relax` step.')
spec.expose_inputs(PwRelaxWorkChain, namespace='relax',
exclude=('clean_workdir', 'structure', 'base_final_scf'),
namespace_options={'required': False, 'populate_defaults': False,
'help': 'Inputs for the `PwRelaxWorkChain` that, when defined, will iteratively relax the structure.'})
spec.expose_inputs(PwBaseWorkChain, namespace='scf',
exclude=('clean_workdir','pw.structure'))
spec.expose_inputs(HpWorkChain, namespace='hubbard',
exclude=('clean_workdir', 'hp.parent_scf', 'hp.parent_hp', 'hp.hubbard_structure'))
spec.input('max_iterations', valid_type=orm.Int, default=lambda: orm.Int(10),
help='Maximum number of iterations of the (relax-)scf-hp cycle.')
spec.input('meta_convergence', valid_type=orm.Bool, default=lambda: orm.Bool(False),
help='Whether performing the self-consistent cycle. If False, it will stop at the first iteration.')
spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(True),
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.')

spec.inputs.validator = validate_inputs
spec.inputs['hubbard']['hp'].validator = None
Expand All @@ -188,44 +156,30 @@ def define(cls, spec):
),
cls.run_hp,
cls.inspect_hp,
if_(cls.should_check_convergence)(cls.check_convergence,),
if_(cls.should_check_convergence)(
cls.check_convergence,
),
),
cls.run_results,
)

spec.output(
'hubbard_structure',
valid_type=HubbardStructureData,
required=False,
help='The Hubbard structure containing the structure and associated Hubbard parameters.'
)

spec.exit_code(
330,
'ERROR_FAILED_TO_DETERMINE_PSEUDO_POTENTIAL',
message='Failed to determine the correct pseudo potential after the structure changed its kind names.'
)
spec.exit_code(
401, 'ERROR_SUB_PROCESS_FAILED_RECON', message='The reconnaissance PwBaseWorkChain sub process failed'
)
spec.exit_code(
402,
'ERROR_SUB_PROCESS_FAILED_RELAX',
message='The PwRelaxWorkChain sub process failed in iteration {iteration}'
)
spec.exit_code(
403,
'ERROR_SUB_PROCESS_FAILED_SCF',
message='The scf PwBaseWorkChain sub process failed in iteration {iteration}'
)
spec.exit_code(
404, 'ERROR_SUB_PROCESS_FAILED_HP', message='The HpWorkChain sub process failed in iteration {iteration}'
)
spec.exit_code(
405, 'ERROR_NON_INTEGER_TOT_MAGNETIZATION',
spec.output('hubbard_structure', valid_type=HubbardStructureData, required=False,
help='The Hubbard structure containing the structure and associated Hubbard parameters.')

spec.exit_code(330, 'ERROR_FAILED_TO_DETERMINE_PSEUDO_POTENTIAL',
message='Failed to determine the correct pseudo potential after the structure changed its kind names.')
spec.exit_code(401, 'ERROR_SUB_PROCESS_FAILED_RECON',
message='The reconnaissance PwBaseWorkChain sub process failed')
spec.exit_code(402, 'ERROR_SUB_PROCESS_FAILED_RELAX',
message='The PwRelaxWorkChain sub process failed in iteration {iteration}')
spec.exit_code(403, 'ERROR_SUB_PROCESS_FAILED_SCF',
message='The scf PwBaseWorkChain sub process failed in iteration {iteration}')
spec.exit_code(404, 'ERROR_SUB_PROCESS_FAILED_HP',
message='The HpWorkChain sub process failed in iteration {iteration}')
spec.exit_code(405, 'ERROR_NON_INTEGER_TOT_MAGNETIZATION',
message='The scf PwBaseWorkChain sub process in iteration {iteration}'\
'returned a non integer total magnetization (threshold exceeded).'
)
'returned a non integer total magnetization (threshold exceeded).')
# yapf: enable

@classmethod
def get_protocol_filepath(cls):
Expand Down Expand Up @@ -289,12 +243,13 @@ def get_builder_from_protocol(

if 'relax_frequency' in inputs:
builder.relax_frequency = orm.Int(inputs['relax_frequency'])
if 'skip_relax_iterations' in inputs:
builder.skip_relax_iterations = orm.Int(inputs['skip_relax_iterations'])

builder.hubbard_structure = hubbard_structure
builder.relax = relax
builder.scf = scf
builder.hubbard = hubbard
builder.skip_first_relax = orm.Bool(inputs['skip_first_relax'])
builder.tolerance_onsite = orm.Float(inputs['tolerance_onsite'])
builder.tolerance_intersite = orm.Float(inputs['tolerance_intersite'])
builder.max_iterations = orm.Int(inputs['max_iterations'])
Expand All @@ -312,7 +267,9 @@ def setup(self):
self.ctx.is_insulator = None
self.ctx.is_magnetic = False
self.ctx.iteration = 0
self.ctx.skip_first_relax = self.inputs.skip_first_relax.value
self.ctx.skip_relax_iterations = 0
if 'skip_relax_iterations' in self.inputs:
self.ctx.skip_relax_iterations = self.inputs.skip_relax_iterations.value
self.ctx.relax_frequency = 1
if 'relax_frequency' in self.inputs:
self.ctx.relax_frequency = self.inputs.relax_frequency.value
Expand Down Expand Up @@ -342,23 +299,35 @@ def should_run_relax(self):
if 'relax' not in self.inputs:
return False

if self.ctx.skip_first_relax:
self.ctx.skip_first_relax = False # only the first one will be skipped
self.report('`skip_first_relax` is set to `True`. Skipping first relaxation.')
if self.ctx.iteration <= self.ctx.skip_relax_iterations:
self.report((
f'`skip_relax_iterations` is set to {self.ctx.skip_relax_iterations}. '
f'Skipping relaxation for iteration {self.ctx.iteration}.'
))
return False

if self.ctx.iteration % self.ctx.relax_frequency != 0:
self.report((
f'`relax_frequency` is set to {self.ctx.relax_frequency}. '
f'Skipping relaxation for iteration {self.ctx.iteration }.'
f'Skipping relaxation for iteration {self.ctx.iteration}.'
))
return False

return 'relax' in self.inputs
return True

def should_check_convergence(self):
"""Return whether to check the convergence of Hubbard parameters."""
return self.inputs.meta_convergence.value
if not self.inputs.meta_convergence.value:
return False

if self.ctx.iteration <= self.ctx.skip_relax_iterations:
self.report((
f'`skip_relax_iterations` is set to {self.ctx.skip_relax_iterations}. '
f'Skipping convergence check for iteration {self.ctx.iteration}.'
))
return False

return True

def should_run_iteration(self):
"""Return whether a new process should be run."""
Expand Down Expand Up @@ -601,10 +570,11 @@ def inspect_hp(self):
self.report(f'hp.x in iteration {self.ctx.iteration} failed with exit status {workchain.exit_status}')
return self.exit_codes.ERROR_SUB_PROCESS_FAILED_HP.format(iteration=self.ctx.iteration)

if not self.inputs.meta_convergence:
if not self.should_check_convergence():
self.ctx.current_hubbard_structure = workchain.outputs.hubbard_structure
self.report('meta convergence is switched off, so not checking convergence of Hubbard parameters.')
self.ctx.is_converged = True
if not self.inputs.meta_convergence:
self.report('meta convergence is switched off, so not checking convergence of Hubbard parameters.')
self.ctx.is_converged = True

def check_convergence(self):
"""Check the convergence of the Hubbard parameters."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ default_inputs:
meta_convergence: True
tolerance_onsite: 0.1
tolerance_intersite: 0.01
skip_first_relax: False
scf:
kpoints_distance: 0.4

Expand Down
2 changes: 1 addition & 1 deletion tests/workflows/protocols/test_hubbard.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_default(fixture_code, data_regression, generate_hubbard_structure, seri
'tolerance_intersite': 1
},
{
'skip_first_relax': True
'skip_relax_iterations': 2
},
{
'relax_frequency': 3
Expand Down
1 change: 0 additions & 1 deletion tests/workflows/protocols/test_hubbard/test_default.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,5 @@ scf:
Co: Co<md5=04edd96127402ab6ffc358660b52a2db>
Li: Li<md5=90ac4658c7606c7ad16e40ce66db5a86>
O: O<md5=721f9895631356f7d4610e60de16fd63>
skip_first_relax: false
tolerance_intersite: 0.01
tolerance_onsite: 0.1
Loading

0 comments on commit 9db7ced

Please sign in to comment.