diff --git a/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py b/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py index 26d160c7..0770b109 100644 --- a/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py +++ b/src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py @@ -7,31 +7,9 @@ class BandsPdosWidget(ipw.VBox): - """ - A widget for plotting band structure and projected density of states (PDOS) data. - - Parameters - ---------- - - bands (optional): A node containing band structure data. - - pdos (optional): A node containing PDOS data. - - Attributes - ---------- - - description: HTML description of the widget. - - dos_atoms_group: Dropdown widget to select the grouping of atoms for PDOS plotting. - - dos_plot_group: Dropdown widget to select the type of PDOS contributions to plot. - - selected_atoms: Text widget to select specific atoms for PDOS plotting. - - update_plot_button: Button widget to update the plot. - - download_button: Button widget to download the data. - - project_bands_box: Checkbox widget to choose whether projected bands should be plotted. - - plot_widget: Plotly widget for band structure and PDOS plot. - - bands_widget: Output widget to display the bandsplot widget. - """ - - def __init__(self, model: BandsPdosModel, bands=None, pdos=None, **kwargs): - if bands is None and pdos is None: - raise ValueError("Either bands or pdos must be provided") + """A widget for plotting band structure and projected density of states (PDOS).""" + def __init__(self, model: BandsPdosModel, **kwargs): super().__init__( children=[LoadingWidget("Loading widgets")], **kwargs, @@ -49,9 +27,6 @@ def __init__(self, model: BandsPdosModel, bands=None, pdos=None, **kwargs): self.rendered = False - self._model.bands = bands - self._model.pdos = pdos - def render(self): if self.rendered: return diff --git a/src/aiidalab_qe/common/bands_pdos/model.py b/src/aiidalab_qe/common/bands_pdos/model.py index 690ef10c..b7189ffd 100644 --- a/src/aiidalab_qe/common/bands_pdos/model.py +++ b/src/aiidalab_qe/common/bands_pdos/model.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json @@ -6,6 +8,7 @@ import traitlets as tl from IPython.display import display +from aiida import orm from aiida.common.extendeddicts import AttributeDict from aiidalab_qe.common.bands_pdos.utils import ( HTML_TAGS, @@ -79,6 +82,37 @@ def __init__(self, *args, **kwargs): lambda _: self._has_pdos or self.needs_projections_controls, ) + @classmethod + def from_nodes( + cls, + bands_node: orm.WorkChainNode | None = None, + pdos_node: orm.WorkChainNode | None = None, + ): + if not (bands_node or pdos_node): + raise ValueError("At least one of the nodes must be provided") + + if bands_node and bands_node.is_finished_ok: + bands = ( + bands_node.outputs.bands + if "bands" in bands_node.outputs + else bands_node.outputs.bands_projwfc + if "bands_projwfc" in bands_node.outputs + else None + ) + else: + bands = None + + if pdos_node and pdos_node.is_finished_ok: + items = {key: getattr(pdos_node.outputs, key) for key in pdos_node.outputs} + pdos = AttributeDict(items) + else: + pdos = None + + if bands or pdos: + return cls(bands=bands, pdos=pdos) + + raise ValueError("Failed to parse at least one node") + def fetch_data(self): """Fetch the data from the nodes.""" if self.bands: diff --git a/src/aiidalab_qe/common/panel.py b/src/aiidalab_qe/common/panel.py index 223de660..3b1c210e 100644 --- a/src/aiidalab_qe/common/panel.py +++ b/src/aiidalab_qe/common/panel.py @@ -521,7 +521,7 @@ class ResultsModel(PanelModel, HasProcess): @property def has_results(self): - node = self._fetch_child_process_node() + node = self.fetch_child_process_node() return node and node.is_finished_ok def update(self): @@ -531,6 +531,30 @@ def update(self): def update_process_status_notification(self): self.process_status_notification = self._get_child_process_status() + def fetch_child_process_node(self, child="this") -> orm.ProcessNode | None: + if not self.process_uuid: + return + child = child.lower() + uuid = getattr(self, f"_{child}_process_uuid") + label = getattr(self, f"_{child}_process_label") + if not uuid: + uuid = ( + orm.QueryBuilder() + .append( + orm.WorkChainNode, + filters={"uuid": self.process_uuid}, + tag="root_process", + ) + .append( + orm.WorkChainNode, + filters={"attributes.process_label": label}, + project="uuid", + with_incoming="root_process", + ) + .first(flat=True) + ) + return orm.load_node(uuid) if uuid else None # type: ignore + def _get_child_process_status(self, child="this"): state, exit_message = self._get_child_state_and_exit_message(child) status = state.upper() @@ -546,7 +570,7 @@ def _get_child_process_status(self, child="this"): def _get_child_state_and_exit_message(self, child="this"): if not ( - (node := self._fetch_child_process_node(child)) + (node := self.fetch_child_process_node(child)) and hasattr(node, "process_state") and node.process_state ): @@ -556,36 +580,12 @@ def _get_child_state_and_exit_message(self, child="this"): return node.process_state.value, None def _get_child_outputs(self, child="this"): - if not (node := self._fetch_child_process_node(child)): + if not (node := self.fetch_child_process_node(child)): outputs = super().outputs child = child if child != "this" else self.identifier return getattr(outputs, child) if child in outputs else AttributeDict({}) return AttributeDict({key: getattr(node.outputs, key) for key in node.outputs}) - def _fetch_child_process_node(self, child="this") -> orm.ProcessNode | None: - if not self.process_uuid: - return - child = child.lower() - uuid = getattr(self, f"_{child}_process_uuid") - label = getattr(self, f"_{child}_process_label") - if not uuid: - uuid = ( - orm.QueryBuilder() - .append( - orm.WorkChainNode, - filters={"uuid": self.process_uuid}, - tag="root_process", - ) - .append( - orm.WorkChainNode, - filters={"attributes.process_label": label}, - project="uuid", - with_incoming="root_process", - ) - .first(flat=True) - ) - return orm.load_node(uuid) if uuid else None # type: ignore - RM = t.TypeVar("RM", bound=ResultsModel) diff --git a/src/aiidalab_qe/plugins/bands/result/model.py b/src/aiidalab_qe/plugins/bands/result/model.py index 1f9df2bb..9124197a 100644 --- a/src/aiidalab_qe/plugins/bands/result/model.py +++ b/src/aiidalab_qe/plugins/bands/result/model.py @@ -6,14 +6,3 @@ class BandsResultsModel(ResultsModel): identifier = "bands" _this_process_label = "BandsWorkChain" - - def get_bands_node(self): - outputs = self._get_child_outputs() - if "bands" in outputs: - return outputs.bands - elif "bands_projwfc" in outputs: - return outputs.bands_projwfc - else: - # If neither 'bands' nor 'bands_projwfc' exist, use 'bands_output' itself - # This is the case for compatibility with older versions of the plugin - return outputs diff --git a/src/aiidalab_qe/plugins/bands/result/result.py b/src/aiidalab_qe/plugins/bands/result/result.py index ad940024..aec5dd7c 100644 --- a/src/aiidalab_qe/plugins/bands/result/result.py +++ b/src/aiidalab_qe/plugins/bands/result/result.py @@ -8,8 +8,8 @@ class BandsResultsPanel(ResultsPanel[BandsResultsModel]): def _render(self): - bands_node = self._model.get_bands_node() - model = BandsPdosModel() - widget = BandsPdosWidget(model=model, bands=bands_node) + bands_node = self._model.fetch_child_process_node() + model = BandsPdosModel.from_nodes(bands_node=bands_node) + widget = BandsPdosWidget(model=model) widget.render() self.children = [widget] diff --git a/src/aiidalab_qe/plugins/electronic_structure/result/model.py b/src/aiidalab_qe/plugins/electronic_structure/result/model.py index 5f8534f7..acccb030 100644 --- a/src/aiidalab_qe/plugins/electronic_structure/result/model.py +++ b/src/aiidalab_qe/plugins/electronic_structure/result/model.py @@ -17,20 +17,6 @@ class ElectronicStructureResultsModel(ResultsModel): _pdos_process_label = "PdosWorkChain" _pdos_process_uuid = None - def get_pdos_node(self): - return self._get_child_outputs("pdos") - - def get_bands_node(self): - outputs = self._get_child_outputs("bands") - if "bands" in outputs: - return outputs.bands - elif "bands_projwfc" in outputs: - return outputs.bands_projwfc - else: - # If neither 'bands' nor 'bands_projwfc' exist, use 'bands_output' itself - # This is the case for compatibility with older versions of the plugin - return outputs - @property def include(self): return all(identifier in self.properties for identifier in self.identifiers) diff --git a/src/aiidalab_qe/plugins/electronic_structure/result/result.py b/src/aiidalab_qe/plugins/electronic_structure/result/result.py index 87eb8d74..7b3ad62c 100644 --- a/src/aiidalab_qe/plugins/electronic_structure/result/result.py +++ b/src/aiidalab_qe/plugins/electronic_structure/result/result.py @@ -8,9 +8,9 @@ class ElectronicStructureResultsPanel(ResultsPanel[ElectronicStructureResultsModel]): def _render(self): - bands_node = self._model.get_bands_node() - pdos_node = self._model.get_pdos_node() - model = BandsPdosModel() - widget = BandsPdosWidget(model=model, bands=bands_node, pdos=pdos_node) + bands_node = self._model.fetch_child_process_node("bands") + pdos_node = self._model.fetch_child_process_node("pdos") + model = BandsPdosModel.from_nodes(bands_node=bands_node, pdos_node=pdos_node) + widget = BandsPdosWidget(model=model) widget.render() self.children = [widget] diff --git a/src/aiidalab_qe/plugins/pdos/result/model.py b/src/aiidalab_qe/plugins/pdos/result/model.py index c010c0f8..55621ed9 100644 --- a/src/aiidalab_qe/plugins/pdos/result/model.py +++ b/src/aiidalab_qe/plugins/pdos/result/model.py @@ -6,6 +6,3 @@ class PdosResultsModel(ResultsModel): identifier = "pdos" _this_process_label = "PdosWorkChain" - - def get_pdos_node(self): - return self._get_child_outputs() diff --git a/src/aiidalab_qe/plugins/pdos/result/result.py b/src/aiidalab_qe/plugins/pdos/result/result.py index eefa4e70..e8ea3b53 100644 --- a/src/aiidalab_qe/plugins/pdos/result/result.py +++ b/src/aiidalab_qe/plugins/pdos/result/result.py @@ -8,8 +8,8 @@ class PdosResultsPanel(ResultsPanel[PdosResultsModel]): def _render(self): - pdos_node = self._model.get_pdos_node() - model = BandsPdosModel() - widget = BandsPdosWidget(model=model, pdos=pdos_node) + pdos_node = self._model.fetch_child_process_node() + model = BandsPdosModel.from_nodes(pdos_node=pdos_node) + widget = BandsPdosWidget(model=model) widget.render() self.children = [widget]