diff --git a/src/aiida_wannier90_workflows/workflows/optimize.py b/src/aiida_wannier90_workflows/workflows/optimize.py index e813976..6a94a93 100644 --- a/src/aiida_wannier90_workflows/workflows/optimize.py +++ b/src/aiida_wannier90_workflows/workflows/optimize.py @@ -31,11 +31,12 @@ def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument parameters = inputs["wannier90"]["wannier90"]["parameters"].get_dict() - if inputs["optimize_disproj"]: + optimize_disproj = inputs.get("optimize_disproj", True) + if optimize_disproj: if all(_ not in parameters for _ in ("dis_proj_min", "dis_proj_max")): return "Trying to optimize dis_proj_min/max but no dis_proj_min/max in wannier90 parameters?" - if "optimize_reference_bands" in inputs and not inputs["optimize_disproj"]: + if "optimize_reference_bands" in inputs and not optimize_disproj: warnings.warn( "`optimize_reference_bands` is provided but `optimize_disproj = False`?" ) @@ -46,22 +47,24 @@ def validate_inputs(inputs, ctx=None): # pylint: disable=unused-argument ): return "No `optimize_reference_bands` but `optimize_bands_distance_threshold` is set?" - if inputs["separate_plotting"]: - plot_inputs = [ - parameters.get(_, False) - for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS # pylint: disable=protected-access - ] + separate_plotting = inputs.get("separate_plotting", False) + plot_inputs = [ + parameters.get(_, False) + # for _ in Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS # pylint: disable=protected-access + for _ in ["wannier_plot"] + ] + if separate_plotting: if not any(plot_inputs): return ( "Trying to separate plotting routines but no " f"{'/'.join(Wannier90OptimizeWorkChain._WANNIER90_PLOT_INPUTS)} in wannier90 parameters?" # pylint: disable=protected-access ) - - if inputs["optimize_disproj"] and not inputs["separate_plotting"]: - warnings.warn( - "`optimize_disproj = True` but `separate_plotting = False`. For optimizing projectability " - "disentanglement, it is highly recommended to run the plotting mode in a separate step." - ) + else: + if optimize_disproj and any(plot_inputs): + warnings.warn( + "`optimize_disproj = True` but `separate_plotting = False`. For optimizing projectability " + "disentanglement, it is highly recommended to run the plotting mode in a separate step." + ) return None @@ -262,6 +265,11 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ ) kwargs.setdefault("projection_type", WannierProjectionType.ATOMIC_PROJECTORS_QE) + if reference_bands and kwargs.get("bands_kpoints", None) is None: + warnings.warn( + "It is recommended to provide both `reference_bands` and `bands_kpoints` so that" + " the seekpath step can be skipped." + ) parent_builder = super().get_builder_from_protocol(codes, structure, **kwargs) if reference_bands is not None: diff --git a/src/aiida_wannier90_workflows/workflows/wannier90.py b/src/aiida_wannier90_workflows/workflows/wannier90.py index 3547f2c..c9754fb 100644 --- a/src/aiida_wannier90_workflows/workflows/wannier90.py +++ b/src/aiida_wannier90_workflows/workflows/wannier90.py @@ -396,6 +396,7 @@ def get_builder_from_protocol( # pylint: disable=unused-argument projection_type=projection_type, disentanglement_type=disentanglement_type, frozen_type=frozen_type, + pseudo_family=pseudo_family, ) # Remove workchain excluded inputs wannier_builder["wannier90"].pop("structure", None)