Skip to content

Commit

Permalink
update inspect from voila review
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Mar 13, 2022
1 parent b1c9c16 commit a3545cf
Show file tree
Hide file tree
Showing 10 changed files with 447 additions and 293 deletions.
2 changes: 2 additions & 0 deletions aiidalab_sssp/inspect/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
__version__ = "0.0.1"
57 changes: 43 additions & 14 deletions aiidalab_sssp/plot.py → aiidalab_sssp/inspect/plot_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,42 @@
import random

import matplotlib.pyplot as plt
import numpy as np


def cmap(label: str) -> str:
"""Return RGB string of color for given standard psp label"""
_, pp_family, pp_z, pp_type, pp_version = label.split("/")

if pp_family == "sg15" and pp_version == "v1.0":
return "#000000"

if pp_family == "sg15" and pp_version == "v1.2":
return "#708090"

if pp_family == "gbrv" and pp_version == "v1":
return "#4682B4"

if pp_family == "psl" and pp_version == "v1.0.0" and pp_type == "us":
return "#F50E02"

if pp_family == "psl" and pp_version == "v1.0.0" and pp_type == "paw":
return "#2D8B00F"

if pp_family == "dojo" and pp_version == "v04":
return "#F9A501"

# TODO: more mapping
# if a unknow type generate random color based on ascii sum
ascn = sum([ord(c) for c in label])
random.seed(ascn)
return "#%06x" % random.randint(0, 0xFFFFFF)


def delta_measure_hist(pseudos: dict, measure_type):

px = 1 / plt.rcParams["figure.dpi"] # pixel in inches
fig, ax = plt.subplots(1, 1, figsize=(1024 * px, 360 * px))
cmap = plt.get_cmap("tab20")
NUM_COLOR = 20
structures = ["X", "XO", "X2O", "XO3", "X2O", "X2O3", "X2O5"]
num_structures = len(structures)

Expand All @@ -23,7 +52,7 @@ def delta_measure_hist(pseudos: dict, measure_type):
ylabel = "Δ -factor"
elif measure_type == "nv_delta":
keyname = "rel_errors_vec_length"
ylabel = "νΔ -factor"
ylabel = "ν -factor"

for i, (label, output) in enumerate(pseudos.items()):
N = len(structures)
Expand All @@ -44,7 +73,13 @@ def delta_measure_hist(pseudos: dict, measure_type):
out_label = f"{pp_z}/{pp_type}({pp_family}-{pp_version})"

ax.bar(
idx + width * i, y_delta, width, color=cmap(i / NUM_COLOR), label=out_label
idx + width * i,
y_delta,
width,
color=cmap(label),
edgecolor="black",
linewidth=1,
label=out_label,
)
ax.legend()
ax.set_title(f"X={element}")
Expand All @@ -65,10 +100,8 @@ def convergence(pseudos: dict, wf_name, measure_name, ylabel, threshold=None):
fig, (ax1, ax2) = plt.subplots(
1, 2, gridspec_kw={"width_ratios": [2, 1]}, figsize=(1024 * px, 360 * px)
)
cmap = plt.get_cmap("tab20")
NUM_COLOR = 20

for i, (label, output) in enumerate(pseudos.items()):
for label, output in pseudos.items():
# Calculate the avg delta measure value
structures = ["X", "XO", "X2O", "XO3", "X2O", "X2O3", "X2O5"]
lst = []
Expand All @@ -94,18 +127,14 @@ def convergence(pseudos: dict, wf_name, measure_name, ylabel, threshold=None):
wfc_cutoff = res["final_output_parameters"]["wfc_cutoff"]

_, pp_family, pp_z, pp_type, pp_version = label.split("/")
out_label = (
f"{pp_z}/{pp_type}(νΔ={avg_delta:.2f})({pp_family}-{pp_version})"
)
out_label = f"{pp_z}/{pp_type}(ν={avg_delta:.2f})({pp_family}-{pp_version})"

ax1.plot(
x_wfc, y_wfc, marker="o", color=cmap(i / NUM_COLOR), label=out_label
)
ax1.plot(x_wfc, y_wfc, marker="o", color=cmap(label), label=out_label)
ax2.plot(
x_rho,
y_rho,
marker="o",
color=cmap(i / NUM_COLOR),
color=cmap(label),
label=f"cutoff wfc = {wfc_cutoff} Ry",
)
except Exception:
Expand Down
Empty file.
65 changes: 65 additions & 0 deletions aiidalab_sssp/inspect/subwidgets/periodic_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import ipywidgets as ipw
from widget_periodictable import PTableWidget

__all__ = ("PeriodicTable",)


class PeriodicTable(ipw.VBox):
"""Wrapper-widget for PTableWidget"""

def __init__(self, extended: bool = True, **kwargs):
self._disabled = kwargs.get("disabled", False)

self.select_any_all = ipw.Checkbox(
value=False,
description="Structures can include any chosen elements (instead of all)",
indent=False,
layout={"width": "auto"},
disabled=self.disabled,
)
self.ptable = PTableWidget(**kwargs)
self.ptable_container = ipw.VBox(
children=(self.select_any_all, self.ptable),
layout={
"width": "auto",
"height": "auto" if extended else "0px",
"visibility": "visible" if extended else "hidden",
},
)

super().__init__(
children=(self.ptable_container,),
layout=kwargs.get("layout", {}),
)

@property
def value(self) -> dict:
"""Return value for wrapped PTableWidget"""

return not self.select_any_all.value, self.ptable.selected_elements.copy()

@property
def disabled(self) -> None:
"""Disable widget"""
return self._disabled

@disabled.setter
def disabled(self, value: bool) -> None:
"""Disable widget"""
if not isinstance(value, bool):
raise TypeError("disabled must be a boolean")

self.select_any_all.disabled = self.ptable.disabled = value

def reset(self):
"""Reset widget"""
self.select_any_all.value = False
self.ptable.selected_elements = {}

def freeze(self):
"""Disable widget"""
self.disabled = True

def unfreeze(self):
"""Activate widget (in its current state)"""
self.disabled = False
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,19 @@
import traitlets
from IPython.display import clear_output, display

from aiidalab_sssp.plot import convergence, delta_measure_hist
from aiidalab_sssp.inspect.plot_utils import convergence, delta_measure_hist


class PlotDeltaMeasureWidget(ipw.VBox):

selected_pseudos = traitlets.Dict()
selected_pseudos = traitlets.Dict(allow_none=True)

def __init__(self):
# measure button
# self.measure_tab = ipw.Tab(title=['Δ-factor', 'νΔ-factor'])
self.measure_tab = ipw.Tab()
self.measure_tab.set_title(0, "Δ-factor")
self.measure_tab.set_title(1, "νΔ-factor")
self.measure_tab.set_title(0, "ν-factor")
self.measure_tab.set_title(1, "Δ-factor")

# Delta mesure
self.output_delta_measure = ipw.Output()
Expand All @@ -28,23 +28,25 @@ def __init__(self):

@traitlets.observe("selected_pseudos")
def _on_pseudos_change(self, change):
out_delta = ipw.Output()
out_nv_delta = ipw.Output()
with out_delta:
fig = delta_measure_hist(change["new"], "delta")
display(fig)
out_delta = ipw.Output()

if change["new"]:
with out_nv_delta:
fig = delta_measure_hist(change["new"], "nv_delta")
display(fig)

with out_nv_delta:
fig = delta_measure_hist(change["new"], "nv_delta")
display(fig)
with out_delta:
fig = delta_measure_hist(change["new"], "delta")
display(fig)

children = [out_delta, out_nv_delta]
children = [out_nv_delta, out_delta]
self.measure_tab.children = children


class _PlotConvergenBaseWidget(ipw.VBox):

selected_pseudos = traitlets.Dict()
selected_pseudos = traitlets.Dict(allow_none=True)

_WF = "Not implement"
_MEASURE = "Not implement"
Expand All @@ -66,14 +68,15 @@ def _on_pseudos_change(self, change):
with self.output:
clear_output(wait=True)

fig = convergence(
change["new"],
wf_name=self._WF,
measure_name=self._MEASURE,
ylabel=self._YLABEL,
threshold=self._THRESHOLD,
)
display(fig)
if change["new"]:
fig = convergence(
change["new"],
wf_name=self._WF,
measure_name=self._MEASURE,
ylabel=self._YLABEL,
threshold=self._THRESHOLD,
)
display(fig)


class PlotCohesiveEnergyConvergeWidget(_PlotConvergenBaseWidget):
Expand Down
Loading

0 comments on commit a3545cf

Please sign in to comment.