Skip to content

Commit

Permalink
♻️ Refactor Wannier90WorkChain.get_builder_from_protocol()
Browse files Browse the repository at this point in the history
Main changes include:

1. Start by gathering all the `inputs` from the various overrides. This means
   moving the overrides related to the input arguments (e.g.
   `plot_wannier_functions`) higher up, just before we start constructing the
   builder.
2. Make sure the user-specified overrides, provided by the `overrides` input
   argument, always take precedence.
3. Avoid using the `recursive_merge_builder` method by keeping the inputs in
   `dict` format. Conversion to node types should happen at the end in the
   `get_builder_from_protocol()` methods of the wrapped processes.
4. Remove double passing of the `pseudo_family` to the
   `get_builder_from_protocol()` call of the `Wannier90BaseWorkChain` through
   _both_ the `meta_parameters` in the `overrides` _and_ the `pseudo_family`
   input argument. We favour the latter.
5. Fix wrongfully passing the inputs (and hence overrides) of the `projwfc`
   namespace to the `pw2wannier90` namespace overrides.
  • Loading branch information
mbercx committed Nov 16, 2023
1 parent 48489f7 commit c0dc666
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 67 deletions.
113 changes: 46 additions & 67 deletions src/aiida_wannier90_workflows/workflows/wannier90.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base class for Wannierisation workflow."""
# pylint: disable=protected-access
import pathlib
import typing as ty

Expand All @@ -10,7 +11,10 @@

from aiida_quantumespresso.common.types import ElectronicType, SpinType
from aiida_quantumespresso.utils.mapping import prepare_process_inputs
from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin
from aiida_quantumespresso.workflows.protocols.utils import (
ProtocolMixin,
recursive_merge,
)
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain

from aiida_wannier90_workflows.common.types import (
Expand Down Expand Up @@ -219,7 +223,7 @@ def get_protocol_filepath(cls) -> pathlib.Path:

@classmethod
def get_protocol_overrides(cls) -> dict:
"""Return the ``pathlib.Path`` to the ``.yaml`` file that defines the protocols."""
"""Get the ``overrides`` for various input arguments of the ``get_builder_from_protocol()`` method."""
from importlib_resources import files
import yaml

Expand Down Expand Up @@ -310,10 +314,7 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
from aiida_wannier90_workflows.utils.workflows.builder.projections import (
guess_wannier_projection_types,
)
from aiida_wannier90_workflows.utils.workflows.builder.submit import (
check_codes,
recursive_merge_builder,
)
from aiida_wannier90_workflows.utils.workflows.builder.submit import check_codes

# Check function arguments
codes = check_codes(codes)
Expand Down Expand Up @@ -356,6 +357,27 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
)
type_check(external_projectors_path, str)

# Adapt overrides based on input arguments
# Note: if overrides are specified, they take precedence!
argument_overrides = Wannier90WorkChain.get_protocol_overrides()

if plot_wannier_functions:
overrides = recursive_merge(
argument_overrides["plot_wannier_functions"], overrides
)

if retrieve_hamiltonian:
overrides = recursive_merge(
argument_overrides["retrieve_hamiltonian"], overrides
)

if retrieve_matrices:
overrides = recursive_merge(
argument_overrides["retrieve_matrices"], overrides
)

inputs = cls.get_protocol_inputs(protocol=protocol, overrides=overrides)

if pseudo_family is None:
if spin_type == SpinType.SPIN_ORBIT:
# I use pseudo-dojo for SOC
Expand All @@ -365,24 +387,13 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
protocol=protocol
)["meta_parameters"]["pseudo_family"]

# Prepare workchain builder
# I need to use explicitly `Wannier90WorkChain.get_protocol_inputs()` instead of
# `cls.get_protocol_inputs()`, because for a subclass of Wannier90WorkChain,
# `cls.get_protocol_inputs()` will call the `get_protocol_inputs` of that subclass,
# which might be different from this base class.
builder = Wannier90WorkChain.get_builder()
inputs = Wannier90WorkChain.get_protocol_inputs(protocol, overrides)
builder = recursive_merge_builder(builder, inputs)

builder["structure"] = structure
builder = cls.get_builder()
builder.structure = structure
builder.clean_workdir = orm.Bool(inputs.get("clean_workdir"))

if not overrides:
overrides = {}

# Prepare wannier90
wannier_overrides = overrides.get("wannier90", {})
# Prepare wannier90 builder
wannier_overrides = inputs.get("wannier90", {})
wannier_overrides.setdefault("meta_parameters", {})
wannier_overrides["meta_parameters"].setdefault("pseudo_family", pseudo_family)
wannier_overrides["meta_parameters"].setdefault(
"exclude_semicore", exclude_semicore
)
Expand All @@ -401,18 +412,14 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
# Remove workchain excluded inputs
wannier_builder["wannier90"].pop("structure", None)
wannier_builder.pop("clean_workdir", None)
builder[
"wannier90"
] = wannier_builder._inputs( # pylint: disable=protected-access
prune=True
)
builder.wannier90 = wannier_builder._inputs(prune=True)

kpoints_distance = Wannier90BaseWorkChain.get_protocol_inputs(
protocol=protocol, overrides=wannier_overrides
)["meta_parameters"]["kpoints_distance"]

# Prepare scf
scf_overrides = overrides.get("scf", {})
# Prepare SCF builder
scf_overrides = inputs.get("scf", {})
scf_builder = get_scf_builder(
code=codes["pw"],
structure=structure,
Expand All @@ -422,14 +429,11 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
spin_type=spin_type,
overrides=scf_overrides,
)
# Remove workchain excluded inputs
scf_builder["pw"].pop("structure", None)
scf_builder.pop("clean_workdir", None)
builder["scf"] = scf_builder._inputs( # pylint: disable=protected-access
prune=True
)
builder.scf = scf_builder._inputs(prune=True)

# Prepare nscf
# Prepare NSCF builder
num_bands = wannier_builder["wannier90"]["parameters"]["num_bands"]
exclude_bands = (
wannier_builder["wannier90"]["parameters"]
Expand All @@ -441,7 +445,7 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
# Since the QE auto generated kpoints might be different from wannier90, here we explicitly
# generate a list of kpoint coordinates to avoid discrepencies.
kpoints = wannier_builder["wannier90"]["kpoints"]
nscf_overrides = overrides.get("nscf", {})
nscf_overrides = inputs.get("nscf", {})
nscf_builder = get_nscf_builder(
code=codes["pw"],
structure=structure,
Expand All @@ -455,11 +459,9 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
# Remove workchain excluded inputs
nscf_builder["pw"].pop("structure", None)
nscf_builder.pop("clean_workdir", None)
builder["nscf"] = nscf_builder._inputs( # pylint: disable=protected-access
prune=True
)
builder.nscf = nscf_builder._inputs(prune=True)

# Prepare projwfc
# Prepare projwfc builder
if projection_type == WannierProjectionType.SCDM:
run_projwfc = True
else:
Expand All @@ -470,24 +472,20 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
else:
run_projwfc = False
if run_projwfc:
projwfc_overrides = overrides.get("projwfc", {})
projwfc_overrides = inputs.get("projwfc", {})
projwfc_builder = ProjwfcBaseWorkChain.get_builder_from_protocol(
code=codes["projwfc"], protocol=protocol, overrides=projwfc_overrides
)
# Remove workchain excluded inputs
projwfc_builder.pop("clean_workdir", None)
builder[
"projwfc"
] = projwfc_builder._inputs( # pylint: disable=protected-access
prune=True
)
builder.projwfc = projwfc_builder._inputs(prune=True)

# Prepare pw2wannier90
# Prepare pw2wannier90 builder
exclude_projectors = None
if exclude_semicore:
pseudo_orbitals = get_pseudo_orbitals(builder["scf"]["pw"]["pseudos"])
exclude_projectors = get_semicore_list(structure, pseudo_orbitals)
pw2wannier90_overrides = overrides.get("projwfc", {})
pw2wannier90_overrides = inputs.get("pw2wannier90", {})
pw2wannier90_builder = Pw2wannier90BaseWorkChain.get_builder_from_protocol(
code=codes["pw2wannier90"],
protocol=protocol,
Expand All @@ -499,26 +497,7 @@ def get_builder_from_protocol( # pylint: disable=unused-argument
)
# Remove workchain excluded inputs
pw2wannier90_builder.pop("clean_workdir", None)
builder[
"pw2wannier90"
] = pw2wannier90_builder._inputs( # pylint: disable=protected-access
prune=True
)

# Apply several overrides
protocol_overrides = Wannier90WorkChain.get_protocol_overrides()
if plot_wannier_functions:
builder = recursive_merge_builder(
builder, protocol_overrides["plot_wannier_functions"]
)
if retrieve_hamiltonian:
builder = recursive_merge_builder(
builder, protocol_overrides["retrieve_hamiltonian"]
)
if retrieve_matrices:
builder = recursive_merge_builder(
builder, protocol_overrides["retrieve_matrices"]
)
builder.pw2wannier90 = pw2wannier90_builder._inputs(prune=True)

# A dictionary containing key info of Wannierisation and will be printed when the function returns.
if summary is None:
Expand Down
26 changes: 26 additions & 0 deletions tests/workflows/protocols/test_wannier90.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,29 @@ def test_force_parity(generate_builder_inputs, data_regression, serialize_builde

assert isinstance(builder, ProcessBuilder)
data_regression.check(serialize_builder(builder))


@pytest.mark.parametrize(
"overrides",
(
{"scf": {"pw": {"parameters": {"ELECTRONS": {"diagonalization": "paro"}}}}},
{"nscf": {"pw": {"parallelization": {"npool": 8}}}},
{"projwfc": {"projwfc": {"metadata": {"options": {"account": "infinite"}}}}},
{
"pw2wannier90": {
"pw2wannier90": {"parameters": {"inputpp": {"scdm_proj": False}}}
}
},
{"wannier90": {"auto_energy_windows_threshold": 0.01}},
),
)
def test_overrides(
generate_builder_inputs, data_regression, serialize_builder, overrides
):
"""Test specifying parameter ``overrides`` for the ``get_builder_from_protocol()`` method."""
inputs = generate_builder_inputs("Si")

builder = Wannier90WorkChain.get_builder_from_protocol(
**inputs, overrides=overrides
)
data_regression.check(serialize_builder(builder))

0 comments on commit c0dc666

Please sign in to comment.