Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add from_nodes class method to instantiate a BandsPdosModel from AiiDA nodes #1037

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 2 additions & 27 deletions src/aiidalab_qe/common/bands_pdos/bandpdoswidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,9 @@


class BandsPdosWidget(ipw.VBox):
"""
A widget for plotting band structure and projected density of states (PDOS) data.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you are removing mos of the Docstring, add it where it should be , so is easy for someone in the future to do changes

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, but later, once we understand the design better 👍 This is not done.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same goes for tests. No need to comment on them not passing. I'll adjust w.r.t the final design.


Parameters
----------
- bands (optional): A node containing band structure data.
- pdos (optional): A node containing PDOS data.

Attributes
----------
- description: HTML description of the widget.
- dos_atoms_group: Dropdown widget to select the grouping of atoms for PDOS plotting.
- dos_plot_group: Dropdown widget to select the type of PDOS contributions to plot.
- selected_atoms: Text widget to select specific atoms for PDOS plotting.
- update_plot_button: Button widget to update the plot.
- download_button: Button widget to download the data.
- project_bands_box: Checkbox widget to choose whether projected bands should be plotted.
- plot_widget: Plotly widget for band structure and PDOS plot.
- bands_widget: Output widget to display the bandsplot widget.
"""

def __init__(self, model: BandsPdosModel, bands=None, pdos=None, **kwargs):
if bands is None and pdos is None:
raise ValueError("Either bands or pdos must be provided")
"""A widget for plotting band structure and projected density of states (PDOS)."""

def __init__(self, model: BandsPdosModel, **kwargs):
super().__init__(
children=[LoadingWidget("Loading widgets")],
**kwargs,
Expand All @@ -49,9 +27,6 @@ def __init__(self, model: BandsPdosModel, bands=None, pdos=None, **kwargs):

self.rendered = False

self._model.bands = bands
self._model.pdos = pdos

def render(self):
if self.rendered:
return
Expand Down
34 changes: 34 additions & 0 deletions src/aiidalab_qe/common/bands_pdos/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import base64
import json

Expand All @@ -6,6 +8,7 @@
import traitlets as tl
from IPython.display import display

from aiida import orm
from aiida.common.extendeddicts import AttributeDict
from aiidalab_qe.common.bands_pdos.utils import (
HTML_TAGS,
Expand Down Expand Up @@ -79,6 +82,37 @@ def __init__(self, *args, **kwargs):
lambda _: self._has_pdos or self.needs_projections_controls,
)

@classmethod
def from_nodes(
cls,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should include some doc string here, to explain why these decisions, and that the bands conditionals is to guarantee backwards compatibility

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Once the design is settled, docstrings 👍

bands_node: orm.WorkChainNode | None = None,
pdos_node: orm.WorkChainNode | None = None,
):
if not (bands_node or pdos_node):
raise ValueError("At least one of the nodes must be provided")

if bands_node and bands_node.is_finished_ok:
bands = (
bands_node.outputs.bands
if "bands" in bands_node.outputs
else bands_node.outputs.bands_projwfc
if "bands_projwfc" in bands_node.outputs
else None
)
else:
bands = None

if pdos_node and pdos_node.is_finished_ok:
items = {key: getattr(pdos_node.outputs, key) for key in pdos_node.outputs}
pdos = AttributeDict(items)
else:
pdos = None

if bands or pdos:
return cls(bands=bands, pdos=pdos)

raise ValueError("Failed to parse at least one node")

def fetch_data(self):
"""Fetch the data from the nodes."""
if self.bands:
Expand Down
54 changes: 27 additions & 27 deletions src/aiidalab_qe/common/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ class ResultsModel(PanelModel, HasProcess):

@property
def has_results(self):
node = self._fetch_child_process_node()
node = self.fetch_child_process_node()
return node and node.is_finished_ok

def update(self):
Expand All @@ -531,6 +531,30 @@ def update(self):
def update_process_status_notification(self):
self.process_status_notification = self._get_child_process_status()

def fetch_child_process_node(self, child="this") -> orm.ProcessNode | None:
if not self.process_uuid:
return
child = child.lower()
uuid = getattr(self, f"_{child}_process_uuid")
label = getattr(self, f"_{child}_process_label")
if not uuid:
uuid = (
orm.QueryBuilder()
.append(
orm.WorkChainNode,
filters={"uuid": self.process_uuid},
tag="root_process",
)
.append(
orm.WorkChainNode,
filters={"attributes.process_label": label},
project="uuid",
with_incoming="root_process",
)
.first(flat=True)
)
return orm.load_node(uuid) if uuid else None # type: ignore

def _get_child_process_status(self, child="this"):
state, exit_message = self._get_child_state_and_exit_message(child)
status = state.upper()
Expand All @@ -546,7 +570,7 @@ def _get_child_process_status(self, child="this"):

def _get_child_state_and_exit_message(self, child="this"):
if not (
(node := self._fetch_child_process_node(child))
(node := self.fetch_child_process_node(child))
and hasattr(node, "process_state")
and node.process_state
):
Expand All @@ -556,36 +580,12 @@ def _get_child_state_and_exit_message(self, child="this"):
return node.process_state.value, None

def _get_child_outputs(self, child="this"):
if not (node := self._fetch_child_process_node(child)):
if not (node := self.fetch_child_process_node(child)):
outputs = super().outputs
child = child if child != "this" else self.identifier
return getattr(outputs, child) if child in outputs else AttributeDict({})
return AttributeDict({key: getattr(node.outputs, key) for key in node.outputs})

def _fetch_child_process_node(self, child="this") -> orm.ProcessNode | None:
if not self.process_uuid:
return
child = child.lower()
uuid = getattr(self, f"_{child}_process_uuid")
label = getattr(self, f"_{child}_process_label")
if not uuid:
uuid = (
orm.QueryBuilder()
.append(
orm.WorkChainNode,
filters={"uuid": self.process_uuid},
tag="root_process",
)
.append(
orm.WorkChainNode,
filters={"attributes.process_label": label},
project="uuid",
with_incoming="root_process",
)
.first(flat=True)
)
return orm.load_node(uuid) if uuid else None # type: ignore


RM = t.TypeVar("RM", bound=ResultsModel)

Expand Down
11 changes: 0 additions & 11 deletions src/aiidalab_qe/plugins/bands/result/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,3 @@ class BandsResultsModel(ResultsModel):
identifier = "bands"

_this_process_label = "BandsWorkChain"

def get_bands_node(self):
outputs = self._get_child_outputs()
if "bands" in outputs:
return outputs.bands
elif "bands_projwfc" in outputs:
return outputs.bands_projwfc
else:
# If neither 'bands' nor 'bands_projwfc' exist, use 'bands_output' itself
# This is the case for compatibility with older versions of the plugin
return outputs
6 changes: 3 additions & 3 deletions src/aiidalab_qe/plugins/bands/result/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

class BandsResultsPanel(ResultsPanel[BandsResultsModel]):
def _render(self):
bands_node = self._model.get_bands_node()
model = BandsPdosModel()
widget = BandsPdosWidget(model=model, bands=bands_node)
bands_node = self._model.fetch_child_process_node()
model = BandsPdosModel.from_nodes(bands_node=bands_node)
widget = BandsPdosWidget(model=model)
widget.render()
self.children = [widget]
14 changes: 0 additions & 14 deletions src/aiidalab_qe/plugins/electronic_structure/result/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,6 @@ class ElectronicStructureResultsModel(ResultsModel):
_pdos_process_label = "PdosWorkChain"
_pdos_process_uuid = None

def get_pdos_node(self):
return self._get_child_outputs("pdos")

def get_bands_node(self):
outputs = self._get_child_outputs("bands")
if "bands" in outputs:
return outputs.bands
elif "bands_projwfc" in outputs:
return outputs.bands_projwfc
else:
# If neither 'bands' nor 'bands_projwfc' exist, use 'bands_output' itself
# This is the case for compatibility with older versions of the plugin
return outputs

@property
def include(self):
return all(identifier in self.properties for identifier in self.identifiers)
Expand Down
8 changes: 4 additions & 4 deletions src/aiidalab_qe/plugins/electronic_structure/result/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

class ElectronicStructureResultsPanel(ResultsPanel[ElectronicStructureResultsModel]):
def _render(self):
bands_node = self._model.get_bands_node()
pdos_node = self._model.get_pdos_node()
model = BandsPdosModel()
widget = BandsPdosWidget(model=model, bands=bands_node, pdos=pdos_node)
bands_node = self._model.fetch_child_process_node("bands")
pdos_node = self._model.fetch_child_process_node("pdos")
model = BandsPdosModel.from_nodes(bands_node=bands_node, pdos_node=pdos_node)
widget = BandsPdosWidget(model=model)
widget.render()
self.children = [widget]
3 changes: 0 additions & 3 deletions src/aiidalab_qe/plugins/pdos/result/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,3 @@ class PdosResultsModel(ResultsModel):
identifier = "pdos"

_this_process_label = "PdosWorkChain"

def get_pdos_node(self):
return self._get_child_outputs()
6 changes: 3 additions & 3 deletions src/aiidalab_qe/plugins/pdos/result/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

class PdosResultsPanel(ResultsPanel[PdosResultsModel]):
def _render(self):
pdos_node = self._model.get_pdos_node()
model = BandsPdosModel()
widget = BandsPdosWidget(model=model, pdos=pdos_node)
pdos_node = self._model.fetch_child_process_node()
model = BandsPdosModel.from_nodes(pdos_node=pdos_node)
widget = BandsPdosWidget(model=model)
widget.render()
self.children = [widget]
Loading