diff --git a/src/aiida_wannier90_workflows/workflows/wannier90.py b/src/aiida_wannier90_workflows/workflows/wannier90.py index c9754fb..67812bd 100644 --- a/src/aiida_wannier90_workflows/workflows/wannier90.py +++ b/src/aiida_wannier90_workflows/workflows/wannier90.py @@ -1,4 +1,5 @@ """Base class for Wannierisation workflow.""" +# pylint: disable=protected-access import pathlib import typing as ty @@ -10,7 +11,10 @@ from aiida_quantumespresso.common.types import ElectronicType, SpinType from aiida_quantumespresso.utils.mapping import prepare_process_inputs -from aiida_quantumespresso.workflows.protocols.utils import ProtocolMixin +from aiida_quantumespresso.workflows.protocols.utils import ( + ProtocolMixin, + recursive_merge, +) from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain from aiida_wannier90_workflows.common.types import ( @@ -219,7 +223,7 @@ def get_protocol_filepath(cls) -> pathlib.Path: @classmethod def get_protocol_overrides(cls) -> dict: - """Return the ``pathlib.Path`` to the ``.yaml`` file that defines the protocols.""" + """Get the ``overrides`` for various input arguments of the ``get_builder_from_protocol()`` method.""" from importlib_resources import files import yaml @@ -310,10 +314,7 @@ def get_builder_from_protocol( # pylint: disable=unused-argument from aiida_wannier90_workflows.utils.workflows.builder.projections import ( guess_wannier_projection_types, ) - from aiida_wannier90_workflows.utils.workflows.builder.submit import ( - check_codes, - recursive_merge_builder, - ) + from aiida_wannier90_workflows.utils.workflows.builder.submit import check_codes # Check function arguments codes = check_codes(codes) @@ -356,6 +357,27 @@ def get_builder_from_protocol( # pylint: disable=unused-argument ) type_check(external_projectors_path, str) + # Adapt overrides based on input arguments + # Note: if overrides are specified, they take precedence! + argument_overrides = Wannier90WorkChain.get_protocol_overrides() + + if plot_wannier_functions: + overrides = recursive_merge( + argument_overrides["plot_wannier_functions"], overrides + ) + + if retrieve_hamiltonian: + overrides = recursive_merge( + argument_overrides["retrieve_hamiltonian"], overrides + ) + + if retrieve_matrices: + overrides = recursive_merge( + argument_overrides["retrieve_matrices"], overrides + ) + + inputs = cls.get_protocol_inputs(protocol=protocol, overrides=overrides) + if pseudo_family is None: if spin_type == SpinType.SPIN_ORBIT: # I use pseudo-dojo for SOC @@ -365,24 +387,13 @@ def get_builder_from_protocol( # pylint: disable=unused-argument protocol=protocol )["meta_parameters"]["pseudo_family"] - # Prepare workchain builder - # I need to use explicitly `Wannier90WorkChain.get_protocol_inputs()` instead of - # `cls.get_protocol_inputs()`, because for a subclass of Wannier90WorkChain, - # `cls.get_protocol_inputs()` will call the `get_protocol_inputs` of that subclass, - # which might be different from this base class. - builder = Wannier90WorkChain.get_builder() - inputs = Wannier90WorkChain.get_protocol_inputs(protocol, overrides) - builder = recursive_merge_builder(builder, inputs) - - builder["structure"] = structure + builder = cls.get_builder() + builder.structure = structure + builder.clean_workdir = orm.Bool(inputs.get("clean_workdir")) - if not overrides: - overrides = {} - - # Prepare wannier90 - wannier_overrides = overrides.get("wannier90", {}) + # Prepare wannier90 builder + wannier_overrides = inputs.get("wannier90", {}) wannier_overrides.setdefault("meta_parameters", {}) - wannier_overrides["meta_parameters"].setdefault("pseudo_family", pseudo_family) wannier_overrides["meta_parameters"].setdefault( "exclude_semicore", exclude_semicore ) @@ -401,18 +412,14 @@ def get_builder_from_protocol( # pylint: disable=unused-argument # Remove workchain excluded inputs wannier_builder["wannier90"].pop("structure", None) wannier_builder.pop("clean_workdir", None) - builder[ - "wannier90" - ] = wannier_builder._inputs( # pylint: disable=protected-access - prune=True - ) + builder.wannier90 = wannier_builder._inputs(prune=True) kpoints_distance = Wannier90BaseWorkChain.get_protocol_inputs( protocol=protocol, overrides=wannier_overrides )["meta_parameters"]["kpoints_distance"] - # Prepare scf - scf_overrides = overrides.get("scf", {}) + # Prepare SCF builder + scf_overrides = inputs.get("scf", {}) scf_builder = get_scf_builder( code=codes["pw"], structure=structure, @@ -422,14 +429,11 @@ def get_builder_from_protocol( # pylint: disable=unused-argument spin_type=spin_type, overrides=scf_overrides, ) - # Remove workchain excluded inputs scf_builder["pw"].pop("structure", None) scf_builder.pop("clean_workdir", None) - builder["scf"] = scf_builder._inputs( # pylint: disable=protected-access - prune=True - ) + builder.scf = scf_builder._inputs(prune=True) - # Prepare nscf + # Prepare NSCF builder num_bands = wannier_builder["wannier90"]["parameters"]["num_bands"] exclude_bands = ( wannier_builder["wannier90"]["parameters"] @@ -441,7 +445,7 @@ def get_builder_from_protocol( # pylint: disable=unused-argument # Since the QE auto generated kpoints might be different from wannier90, here we explicitly # generate a list of kpoint coordinates to avoid discrepencies. kpoints = wannier_builder["wannier90"]["kpoints"] - nscf_overrides = overrides.get("nscf", {}) + nscf_overrides = inputs.get("nscf", {}) nscf_builder = get_nscf_builder( code=codes["pw"], structure=structure, @@ -455,11 +459,9 @@ def get_builder_from_protocol( # pylint: disable=unused-argument # Remove workchain excluded inputs nscf_builder["pw"].pop("structure", None) nscf_builder.pop("clean_workdir", None) - builder["nscf"] = nscf_builder._inputs( # pylint: disable=protected-access - prune=True - ) + builder.nscf = nscf_builder._inputs(prune=True) - # Prepare projwfc + # Prepare projwfc builder if projection_type == WannierProjectionType.SCDM: run_projwfc = True else: @@ -470,24 +472,20 @@ def get_builder_from_protocol( # pylint: disable=unused-argument else: run_projwfc = False if run_projwfc: - projwfc_overrides = overrides.get("projwfc", {}) + projwfc_overrides = inputs.get("projwfc", {}) projwfc_builder = ProjwfcBaseWorkChain.get_builder_from_protocol( code=codes["projwfc"], protocol=protocol, overrides=projwfc_overrides ) # Remove workchain excluded inputs projwfc_builder.pop("clean_workdir", None) - builder[ - "projwfc" - ] = projwfc_builder._inputs( # pylint: disable=protected-access - prune=True - ) + builder.projwfc = projwfc_builder._inputs(prune=True) - # Prepare pw2wannier90 + # Prepare pw2wannier90 builder exclude_projectors = None if exclude_semicore: pseudo_orbitals = get_pseudo_orbitals(builder["scf"]["pw"]["pseudos"]) exclude_projectors = get_semicore_list(structure, pseudo_orbitals) - pw2wannier90_overrides = overrides.get("projwfc", {}) + pw2wannier90_overrides = inputs.get("pw2wannier90", {}) pw2wannier90_builder = Pw2wannier90BaseWorkChain.get_builder_from_protocol( code=codes["pw2wannier90"], protocol=protocol, @@ -499,26 +497,7 @@ def get_builder_from_protocol( # pylint: disable=unused-argument ) # Remove workchain excluded inputs pw2wannier90_builder.pop("clean_workdir", None) - builder[ - "pw2wannier90" - ] = pw2wannier90_builder._inputs( # pylint: disable=protected-access - prune=True - ) - - # Apply several overrides - protocol_overrides = Wannier90WorkChain.get_protocol_overrides() - if plot_wannier_functions: - builder = recursive_merge_builder( - builder, protocol_overrides["plot_wannier_functions"] - ) - if retrieve_hamiltonian: - builder = recursive_merge_builder( - builder, protocol_overrides["retrieve_hamiltonian"] - ) - if retrieve_matrices: - builder = recursive_merge_builder( - builder, protocol_overrides["retrieve_matrices"] - ) + builder.pw2wannier90 = pw2wannier90_builder._inputs(prune=True) # A dictionary containing key info of Wannierisation and will be printed when the function returns. if summary is None: diff --git a/tests/workflows/protocols/test_wannier90.py b/tests/workflows/protocols/test_wannier90.py index b4194f0..549d9c3 100644 --- a/tests/workflows/protocols/test_wannier90.py +++ b/tests/workflows/protocols/test_wannier90.py @@ -151,3 +151,29 @@ def test_force_parity(generate_builder_inputs, data_regression, serialize_builde assert isinstance(builder, ProcessBuilder) data_regression.check(serialize_builder(builder)) + + +@pytest.mark.parametrize( + "overrides", + ( + {"scf": {"pw": {"parameters": {"ELECTRONS": {"diagonalization": "paro"}}}}}, + {"nscf": {"pw": {"parallelization": {"npool": 8}}}}, + {"projwfc": {"projwfc": {"metadata": {"options": {"account": "infinite"}}}}}, + { + "pw2wannier90": { + "pw2wannier90": {"parameters": {"inputpp": {"scdm_proj": False}}} + } + }, + {"wannier90": {"auto_energy_windows_threshold": 0.01}}, + ), +) +def test_overrides( + generate_builder_inputs, data_regression, serialize_builder, overrides +): + """Test specifying parameter ``overrides`` for the ``get_builder_from_protocol()`` method.""" + inputs = generate_builder_inputs("Si") + + builder = Wannier90WorkChain.get_builder_from_protocol( + **inputs, overrides=overrides + ) + data_regression.check(serialize_builder(builder))