Skip to content

Commit

Permalink
Controlled submission of concurrent HpBaseWorkChains
Browse files Browse the repository at this point in the history
The workchains parallelizing atoms and q-points could not control
the number of concurrent sub-processes, specifically `HpBaseWorkChain`.
The user can now optionally control the maximum number of simultaneously
submitted `HpBaseWorkChains`.

This is extremely useful when parallelizing over both atoms and q-points,
as it gives the user more control over resource utilization.
  • Loading branch information
t-reents authored Dec 15, 2023
1 parent 296ba48 commit 8be0c77
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 32 deletions.
27 changes: 27 additions & 0 deletions src/aiida_quantumespresso_hp/utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""General utilies."""
from __future__ import annotations

from typing import List


def set_tot_magnetization(input_parameters: dict, tot_magnetization: float) -> bool:
"""Set the total magnetization based on its value and the input parameters.
Expand Down Expand Up @@ -37,3 +39,28 @@ def is_perturb_only_atom(parameters: dict) -> int | None:
break

return match


def distribute_base_workchains(n_atoms: int, n_total: int) -> List[int]:
    """Distribute the maximum number of `BaseWorkChains` to be launched.

    The number of `BaseWorkChains` is distributed over the number of atoms.
    Each element of the resulting list corresponds to the number of q-point
    `BaseWorkChains` to be launched for one atom, in case q-point
    parallelization is used. Otherwise, the method only takes care of limiting
    the number of `HpParallelizeAtomsWorkChain` to be launched in parallel.

    :param n_atoms: The number of atoms.
    :param n_total: The number of base workchains to be launched.
    :return: The number of base workchains to be launched for each atom.
    """
    quotient, remainder = divmod(n_total, n_atoms)

    # The first `remainder` atoms receive one extra workchain each.
    n_distributed = [quotient + 1] * remainder + [quotient] * (n_atoms - remainder)

    # When n_total < n_atoms some atoms would receive zero workchains;
    # drop them so every entry launches at least one workchain.
    return [count for count in n_distributed if count != 0]
5 changes: 5 additions & 0 deletions src/aiida_quantumespresso_hp/workflows/hp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def define(cls, spec):
'for any non-periodic directions.')
spec.input('parallelize_atoms', valid_type=orm.Bool, default=lambda: orm.Bool(False))
spec.input('parallelize_qpoints', valid_type=orm.Bool, default=lambda: orm.Bool(False))
spec.input('max_concurrent_base_workchains', valid_type=orm.Int, required=False)
spec.outline(
cls.validate_qpoints,
if_(cls.should_parallelize_atoms)(
Expand Down Expand Up @@ -106,6 +107,8 @@ def get_builder_from_protocol(cls, code, protocol=None, parent_scf_folder=None,
data['parallelize_atoms'] = orm.Bool(inputs['parallelize_atoms'])
if 'parallelize_qpoints' in inputs:
data['parallelize_qpoints'] = orm.Bool(inputs['parallelize_qpoints'])
if 'max_concurrent_base_workchains' in inputs:
data['max_concurrent_base_workchains'] = orm.Int(inputs['max_concurrent_base_workchains'])

builder = cls.get_builder()
builder._data = data # pylint: disable=protected-access
Expand Down Expand Up @@ -163,6 +166,8 @@ def run_parallel_workchain(self):
inputs.clean_workdir = self.inputs.clean_workdir
inputs.parallelize_qpoints = self.inputs.parallelize_qpoints
inputs.hp.qpoints = self.ctx.qpoints
if 'max_concurrent_base_workchains' in self.inputs:
inputs.max_concurrent_base_workchains = self.inputs.max_concurrent_base_workchains
running = self.submit(HpParallelizeAtomsWorkChain, **inputs)
self.report(f'running in parallel, launching HpParallelizeAtomsWorkChain<{running.pk}>')
return ToContext(workchain=running)
Expand Down
34 changes: 24 additions & 10 deletions src/aiida_quantumespresso_hp/workflows/hp/parallelize_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
"""Work chain to launch a Quantum Espresso hp.x calculation parallelizing over the Hubbard atoms."""
from aiida import orm
from aiida.common import AttributeDict
from aiida.engine import WorkChain
from aiida.engine import WorkChain, while_
from aiida.plugins import CalculationFactory, WorkflowFactory

from aiida_quantumespresso_hp.utils.general import distribute_base_workchains

PwCalculation = CalculationFactory('quantumespresso.pw')
HpCalculation = CalculationFactory('quantumespresso.hp')
HpBaseWorkChain = WorkflowFactory('quantumespresso.hp.base')
Expand All @@ -21,12 +23,15 @@ def define(cls, spec):
super().define(spec)
spec.expose_inputs(HpBaseWorkChain, exclude=('only_initialization', 'clean_workdir'))
spec.input('parallelize_qpoints', valid_type=orm.Bool, default=lambda: orm.Bool(False))
spec.input('max_concurrent_base_workchains', valid_type=orm.Int, required=False)
spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.')
spec.outline(
cls.run_init,
cls.inspect_init,
cls.run_atoms,
while_(cls.should_run_atoms)(
cls.run_atoms,
),
cls.inspect_atoms,
cls.run_final,
cls.inspect_final,
Expand Down Expand Up @@ -66,28 +71,37 @@ def inspect_init(self):
self.report(f'initialization work chain {workchain} failed with status {workchain.exit_status}, aborting.')
return self.exit_codes.ERROR_INITIALIZATION_WORKCHAIN_FAILED

def run_atoms(self):
"""Run a separate `HpBaseWorkChain` for each of the defined Hubbard atoms."""
workchain = self.ctx.initialization

output_params = workchain.outputs.parameters.get_dict()
hubbard_sites = output_params['hubbard_sites']
self.ctx.hubbard_sites = list(output_params['hubbard_sites'].items())

def should_run_atoms(self):
    """Return whether there are still Hubbard atoms left to launch."""
    # Non-empty list of remaining sites means more work to submit.
    return bool(self.ctx.hubbard_sites)

def run_atoms(self):
"""Run a separate `HpBaseWorkChain` for each of the defined Hubbard atoms."""
parallelize_qpoints = self.inputs.parallelize_qpoints.value
workflow = HpParallelizeQpointsWorkChain if parallelize_qpoints else HpBaseWorkChain

for site_index, site_kind in hubbard_sites.items():
max_concurrent_base_workchains_sites = [-1] * len(self.ctx.hubbard_sites)
if 'max_concurrent_base_workchains' in self.inputs:
max_concurrent_base_workchains_sites = distribute_base_workchains(
len(self.ctx.hubbard_sites), self.inputs.max_concurrent_base_workchains.value
)

for max_concurrent_base_workchains_site in max_concurrent_base_workchains_sites:
site_index, site_kind = self.ctx.hubbard_sites.pop(0)
do_only_key = f'perturb_only_atom({site_index})'
key = f'atom_{site_index}'

inputs = AttributeDict(self.exposed_inputs(HpBaseWorkChain))
inputs.clean_workdir = self.inputs.clean_workdir
inputs.hp.parameters = inputs.hp.parameters.get_dict()
inputs.hp.parameters['INPUTHP'][do_only_key] = True
inputs.hp.parameters = orm.Dict(dict=inputs.hp.parameters)
inputs.hp.parameters = orm.Dict(inputs.hp.parameters)
inputs.metadata.call_link_label = key

if parallelize_qpoints and max_concurrent_base_workchains_site != -1:
inputs.max_concurrent_base_workchains = orm.Int(max_concurrent_base_workchains_site)
node = self.submit(workflow, **inputs)
self.to_context(**{key: node})
name = workflow.__name__
Expand Down
21 changes: 14 additions & 7 deletions src/aiida_quantumespresso_hp/workflows/hp/parallelize_qpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"""Work chain to launch a Quantum Espresso hp.x calculation parallelizing over the Hubbard atoms."""
from aiida import orm
from aiida.common import AttributeDict
from aiida.engine import WorkChain
from aiida.engine import WorkChain, while_
from aiida.plugins import CalculationFactory, WorkflowFactory

from aiida_quantumespresso_hp.utils.general import is_perturb_only_atom
Expand All @@ -29,12 +29,15 @@ def define(cls, spec):
# yapf: disable
super().define(spec)
spec.expose_inputs(HpBaseWorkChain, exclude=('only_initialization', 'clean_workdir'))
spec.input('max_concurrent_base_workchains', valid_type=orm.Int, required=False)
spec.input('clean_workdir', valid_type=orm.Bool, default=lambda: orm.Bool(False),
help='If `True`, work directories of all called calculation will be cleaned at the end of execution.')
spec.outline(
cls.run_init,
cls.inspect_init,
cls.run_qpoints,
while_(cls.should_run_qpoints)(
cls.run_qpoints,
),
cls.inspect_qpoints,
cls.run_final,
cls.results
Expand Down Expand Up @@ -75,14 +78,18 @@ def inspect_init(self):
self.report(f'initialization work chain {workchain} failed with status {workchain.exit_status}, aborting.')
return self.exit_codes.ERROR_INITIALIZATION_WORKCHAIN_FAILED

def run_qpoints(self):
"""Run a separate `HpBaseWorkChain` for each of the q points."""
workchain = self.ctx.initialization
self.ctx.qpoints = list(range(workchain.outputs.parameters.dict.number_of_qpoints))

number_of_qpoints = workchain.outputs.parameters.dict.number_of_qpoints
def should_run_qpoints(self):
    """Return whether there are q points that still need to be launched."""
    # Non-empty queue of q-point indices means more work to submit.
    return bool(self.ctx.qpoints)

for qpoint_index in range(number_of_qpoints):
def run_qpoints(self):
"""Run a separate `HpBaseWorkChain` for each of the q points."""
n_base_parallel = self.inputs.max_concurrent_base_workchains.value if 'max_concurrent_base_workchains' in self.inputs else len(self.ctx.qpoints)

for _ in self.ctx.qpoints[:n_base_parallel]:
qpoint_index = self.ctx.qpoints.pop(0)
key = f'qpoint_{qpoint_index + 1}' # to keep consistency with QE
inputs = AttributeDict(self.exposed_inputs(HpBaseWorkChain))
inputs.clean_workdir = self.inputs.clean_workdir
Expand Down
12 changes: 12 additions & 0 deletions tests/utils/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,15 @@ def test_is_perturb_only_atom():

parameters = {'perturb_only_atom(1)': False}
assert is_perturb_only_atom(parameters) is None


def test_distribute_base_wcs():
    """Test the `distribute_base_workchains` function."""
    from aiida_quantumespresso_hp.utils.general import distribute_base_workchains

    # The remainder workchains go to the leading atoms (e.g. 2 atoms with 3
    # workchains yields [2, 1]); atoms that would receive zero workchains are
    # dropped entirely (e.g. 7 atoms with 5 workchains yields [1] * 5).
    assert distribute_base_workchains(1, 1) == [1]
    assert distribute_base_workchains(1, 2) == [2]
    assert distribute_base_workchains(2, 1) == [1]
    assert distribute_base_workchains(2, 2) == [1, 1]
    assert distribute_base_workchains(2, 3) == [2, 1]
    assert distribute_base_workchains(7, 5) == [1] * 5
44 changes: 36 additions & 8 deletions tests/workflows/hp/test_parallelize_atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
def generate_workchain_atoms(generate_workchain, generate_inputs_hp, generate_hubbard_structure):
"""Generate an instance of a `HpParallelizeAtomsWorkChain`."""

def _generate_workchain_atoms(inputs=None, parallelize_qpoints=False):
from aiida.orm import Bool
def _generate_workchain_atoms(hp_inputs=None, parallelize_qpoints=False, max_concurrent_base_workchains=None):
from aiida.orm import Bool, Int
entry_point = 'quantumespresso.hp.parallelize_atoms'
inputs = generate_inputs_hp(inputs=inputs)
inputs['hubbard_structure'] = generate_hubbard_structure()
inputs['parallelize_qpoints'] = Bool(parallelize_qpoints)
process = generate_workchain(entry_point, {'hp': inputs})
hp_inputs = generate_inputs_hp(inputs=hp_inputs)
hp_inputs['hubbard_structure'] = generate_hubbard_structure()
hp_inputs['parallelize_qpoints'] = Bool(parallelize_qpoints)
inputs = {'hp': hp_inputs}
if max_concurrent_base_workchains is not None:
inputs['max_concurrent_base_workchains'] = Int(max_concurrent_base_workchains)
process = generate_workchain(entry_point, inputs)

return process

Expand Down Expand Up @@ -69,19 +72,44 @@ def test_run_atoms(generate_workchain_atoms, generate_hp_workchain_node):
"""Test `HpParallelizeAtomsWorkChain.run_atoms`."""
process = generate_workchain_atoms()
process.ctx.initialization = generate_hp_workchain_node()

output_params = process.ctx.initialization.outputs.parameters.get_dict()
process.ctx.hubbard_sites = list(output_params['hubbard_sites'].items())
process.run_atoms()

assert 'atom_1' in process.ctx
assert 'atom_2' in process.ctx


@pytest.mark.usefixtures('aiida_profile')
def test_run_atoms_max_concurrent(generate_workchain_atoms, generate_hp_workchain_node):
    """Test `HpParallelizeAtomsWorkChain.run_atoms`.

    The number of concurrent `BaseWorkChains` is limited to `1`.
    """
    process = generate_workchain_atoms(max_concurrent_base_workchains=1)
    process.ctx.initialization = generate_hp_workchain_node()
    # Populate the context queue of Hubbard sites, mimicking what the
    # initialization steps of the work chain would do in a real run.
    output_params = process.ctx.initialization.outputs.parameters.get_dict()
    process.ctx.hubbard_sites = list(output_params['hubbard_sites'].items())

    # First iteration: the limit of 1 means only one atom is submitted.
    assert process.should_run_atoms()
    process.run_atoms()
    assert 'atom_1' in process.ctx
    assert 'atom_2' not in process.ctx
    # Second iteration submits the remaining atom.
    assert process.should_run_atoms()
    process.run_atoms()
    assert 'atom_1' in process.ctx
    assert 'atom_2' in process.ctx

    # Queue exhausted: the while_ loop should terminate.
    assert not process.should_run_atoms()


@pytest.mark.usefixtures('aiida_profile')
def test_run_atoms_with_qpoints(generate_workchain_atoms, generate_hp_workchain_node):
"""Test `HpParallelizeAtomsWorkChain.run_atoms` with q point parallelization."""
process = generate_workchain_atoms()
process.ctx.initialization = generate_hp_workchain_node()

output_params = process.ctx.initialization.outputs.parameters.get_dict()
process.ctx.hubbard_sites = list(output_params['hubbard_sites'].items())
process.run_atoms()

# Don't know how to test something like the following
Expand Down
41 changes: 34 additions & 7 deletions tests/workflows/hp/test_parallelize_qpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,19 @@
def generate_workchain_qpoints(generate_workchain, generate_inputs_hp, generate_hubbard_structure):
"""Generate an instance of a `HpParallelizeQpointsWorkChain`."""

def _generate_workchain_qpoints(inputs=None):
def _generate_workchain_qpoints(hp_inputs=None, max_concurrent_base_workchains=None):
from aiida.orm import Int
entry_point = 'quantumespresso.hp.parallelize_qpoints'

if inputs is None:
inputs = {'perturb_only_atom(1)': True}
if hp_inputs is None:
hp_inputs = {'perturb_only_atom(1)': True}

inputs = generate_inputs_hp(inputs=inputs)
inputs['hubbard_structure'] = generate_hubbard_structure()
process = generate_workchain(entry_point, {'hp': inputs})
hp_inputs = generate_inputs_hp(inputs=hp_inputs)
hp_inputs['hubbard_structure'] = generate_hubbard_structure()
inputs = {'hp': hp_inputs}
if max_concurrent_base_workchains is not None:
inputs['max_concurrent_base_workchains'] = Int(max_concurrent_base_workchains)
process = generate_workchain(entry_point, inputs)

return process

Expand Down Expand Up @@ -56,7 +60,7 @@ def test_validate_inputs_invalid_parameters(generate_workchain_qpoints):
"""Test `HpParallelizeQpointsWorkChain.validate_inputs`."""
match = r'The parameters in `hp.parameters` do not specify the required key `INPUTHP.pertub_only_atom`'
with pytest.raises(ValueError, match=match):
generate_workchain_qpoints(inputs={})
generate_workchain_qpoints(hp_inputs={})


@pytest.mark.usefixtures('aiida_profile')
Expand All @@ -73,13 +77,36 @@ def test_run_qpoints(generate_workchain_qpoints, generate_hp_workchain_node):
"""Test `HpParallelizeQpointsWorkChain.run_qpoints`."""
process = generate_workchain_qpoints()
process.ctx.initialization = generate_hp_workchain_node()
process.ctx.qpoints = list(range(process.ctx.initialization.outputs.parameters.dict.number_of_qpoints))

process.run_qpoints()
# to keep consistency with QE we start from 1
assert 'qpoint_1' in process.ctx
assert 'qpoint_2' in process.ctx


@pytest.mark.usefixtures('aiida_profile')
def test_run_qpoints_max_concurrent(generate_workchain_qpoints, generate_hp_workchain_node):
    """Test `HpParallelizeQpointsWorkChain.run_qpoints`.

    The number of concurrent `BaseWorkChains` is limited to `1`.
    """
    process = generate_workchain_qpoints(max_concurrent_base_workchains=1)
    process.ctx.initialization = generate_hp_workchain_node()
    # Populate the context queue of q-point indices, mimicking what the
    # initialization steps of the work chain would do in a real run.
    process.ctx.qpoints = list(range(process.ctx.initialization.outputs.parameters.dict.number_of_qpoints))

    # First iteration: the limit of 1 means only one q point is submitted.
    # Context keys are 1-based to keep consistency with Quantum ESPRESSO.
    assert process.should_run_qpoints()
    process.run_qpoints()
    assert 'qpoint_1' in process.ctx
    assert 'qpoint_2' not in process.ctx

    # Second iteration submits the remaining q point, exhausting the queue.
    assert process.should_run_qpoints()
    process.run_qpoints()
    assert 'qpoint_1' in process.ctx
    assert 'qpoint_2' in process.ctx
    assert not process.should_run_qpoints()


@pytest.mark.usefixtures('aiida_profile')
def test_inspect_init(generate_workchain_qpoints, generate_hp_workchain_node):
"""Test `HpParallelizeQpointsWorkChain.inspect_init`."""
Expand Down

0 comments on commit 8be0c77

Please sign in to comment.