Skip to content

Commit

Permalink
Simplify get_builder_from_protocol in ProjwfcBandsWorkChain
Browse files Browse the repository at this point in the history
This commit mainly simplifies the current version of the
`get_builder_from_protocol` method in `ProjwfcBandsWorkChain`.

Moreover, it adds support for overrides containing standard Python datatypes,
e.g. `kpoints_distance` specified as a float`.
  • Loading branch information
t-reents committed Jul 5, 2024
1 parent d6794cf commit 717756b
Showing 1 changed file with 4 additions and 20 deletions.
24 changes: 4 additions & 20 deletions src/aiida_wannier90_workflows/workflows/projwfcbands.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ
"""
from aiida_wannier90_workflows.utils.workflows.builder.submit import (
recursive_merge_builder,
recursive_merge_container,
)

type_check(pw_code, (str, int, orm.Code))
Expand All @@ -116,13 +115,9 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ
type_check(protocol, str, allow_none=True)
type_check(overrides, dict, allow_none=True)

# Prepare workchain builder
# # Prepare workchain builder
builder = cls.get_builder()

protocol_inputs = cls.get_protocol_inputs(
protocol=protocol, overrides=overrides
)

projwfc_overrides = None
if overrides:
projwfc_overrides = overrides.pop("projwfc", None)
Expand All @@ -137,25 +132,14 @@ def get_builder_from_protocol( # pylint: disable=arguments-differ

# By default do not run relax
pwbands_builder.pop("relax", None)
inputs = pwbands_builder._inputs(prune=True) # pylint: disable=protected-access

projwfc_builder = ProjwfcBaseWorkChain.get_builder_from_protocol(
projwfc_code, protocol=protocol, overrides=projwfc_overrides
)
projwfc_builder.pop("clean_workdir", None)

inputs["projwfc"] = projwfc_builder._inputs( # pylint: disable=protected-access
prune=True
)
inputs["projwfc"].pop("clean_workdir", None)

# Need to convert `clean_workdir` to `orm.Bool`
if "clean_workdir" in protocol_inputs:
protocol_inputs["clean_workdir"] = orm.Bool(
protocol_inputs["clean_workdir"]
)

inputs = recursive_merge_container(inputs, protocol_inputs)
builder = recursive_merge_builder(builder, inputs)
builder.projwfc = projwfc_builder
builder = recursive_merge_builder(builder, pwbands_builder)

return builder

Expand Down

0 comments on commit 717756b

Please sign in to comment.