From 8a44d47f985da28574605e773c8e2782d32b21e3 Mon Sep 17 00:00:00 2001 From: Jimmy Shen <14003693+jmmshn@users.noreply.github.com> Date: Tue, 21 May 2024 14:45:17 -0700 Subject: [PATCH] Plotly Support (#195) * plot fed * plot fef * plot fed * plot fed * plot fed * plot fed * remove uid * remove uid * moved phase plot * moved phase plot * moved phase plot * moved phase plot * moved phase plot * plot tests plot tests plot tests plot tests plot tests * plot tests * plot tests --- pymatgen/analysis/defects/plotting/phases.py | 142 +------- pymatgen/analysis/defects/plotting/thermo.py | 337 +++++++++++++++++++ pymatgen/analysis/defects/plotting/utils.py | 28 ++ pymatgen/analysis/defects/thermo.py | 33 +- pymatgen/analysis/defects/utils.py | 8 +- tests/conftest.py | 46 ++- tests/plotting/__init__.py | 0 tests/plotting/test_thermo.py | 13 + tests/test_thermo.py | 2 + 9 files changed, 457 insertions(+), 152 deletions(-) create mode 100644 pymatgen/analysis/defects/plotting/thermo.py create mode 100644 pymatgen/analysis/defects/plotting/utils.py create mode 100644 tests/plotting/__init__.py create mode 100644 tests/plotting/test_thermo.py diff --git a/pymatgen/analysis/defects/plotting/phases.py b/pymatgen/analysis/defects/plotting/phases.py index ecd008c1..2f15a816 100644 --- a/pymatgen/analysis/defects/plotting/phases.py +++ b/pymatgen/analysis/defects/plotting/phases.py @@ -1,142 +1,6 @@ """Plotting functions for competing phases.""" -from __future__ import annotations +# Contents moved to pymatgen.analysis.defects.plotting.thermo +from .thermo import plot_chempot_2d -import logging -from typing import TYPE_CHECKING - -from matplotlib import pyplot as plt -from matplotlib.patches import Polygon -from pymatgen.util.string import latexify -from scipy.spatial import ConvexHull - -if TYPE_CHECKING: - from matplotlib.axes import Axes - from pymatgen.analysis.defects.thermo import FormationEnergyDiagram - from pymatgen.core import Element - -# check if labellines is installed -try: - from labellines import labelLines -except ImportError: - - def labelLines(*args, **kwargs) -> None: # noqa: ARG001, ANN002 - """Dummy function if labellines is not installed.""" - - -__author__ = "Jimmy Shen" -__copyright__ = "Copyright 2022, The Materials Project" -__maintainer__ = "Jimmy Shen @jmmshn" -__date__ = "July 2023" - -logger = logging.getLogger(__name__) - - -def plot_chempot_2d( - fed: FormationEnergyDiagram, - x_element: Element, - y_element: Element, - ax: Axes | None = None, - min_mu: float = -5.0, - label_lines: bool = False, - x_vals: list[float] | None = None, - label_fontsize: int = 12, -) -> None: - """Plot the chemical potential diagram for two elements. - - Args: - fed: - The formation energy diagram. - x_element: - The element to use for the x-axis. - y_element: - The element to use for the y-axis. - ax: - The matplotlib axes to plot on. If None, a new figure will be created. - min_mu: - The minimum chemical potential to plot. - label_lines: - Whether to label the lines with the competing phases. Requires Labellines to be installed. - x_vals: - The x position of the line labels. If None, defaults will be used. - label_fontsize: - The fontsize for the line labels. - """ - PLOT_PADDING = 0.1 - ax = ax or plt.gca() - hull2d = _convex_hull_2d( - fed.chempot_limits, - x_element=x_element, - y_element=y_element, - competing_phases=fed.competing_phases, - ) - x_min = float("inf") - y_min = float("inf") - clip_path = [] - for p1, p2, phase in hull2d: - p_txt = ", ".join(map(latexify, phase.keys())) - ax.axline(p1, p2, label=p_txt, color="k") - ax.scatter(p1[0], p1[1], color="k") - x_m_ = p1[0] if p1[0] > min_mu else float("inf") - y_m_ = p1[1] if p1[1] > min_mu else float("inf") - x_min = min(x_min, x_m_) - y_min = min(y_min, y_m_) - clip_path.append(p1) - - patch = Polygon( - clip_path, - closed=True, - ) - ax.add_patch(patch) - - ax.set_xlabel(rf"$\Delta\mu_{{{x_element}}}$ (eV)") - ax.set_ylabel(rf"$\Delta\mu_{{{y_element}}}$ (eV)") - ax.set_xlim(x_min - PLOT_PADDING, 0 + PLOT_PADDING) - ax.set_ylim(y_min - PLOT_PADDING, 0 + PLOT_PADDING) - if label_lines: - labelLines(ax.get_lines(), align=False, xvals=x_vals, fontsize=label_fontsize) - - -def _convex_hull_2d( - points: list[dict], - x_element: Element, - y_element: Element, - competing_phases: list | None = None, -) -> list: - """Compute the convex hull of a set of points in 2D. - - Args: - points: - A list of dictionaries with keys "x" and "y" and values as floats. - x_element: - The element to use for the x-axis. - y_element: - The element to use for the y-axis. - tol: - The tolerance for determining if two points are the same in the 2D plane. - competing_phases: - A list of competing phases for each point. - - Returns: - A list of dictionaries with keys "x" and "y" that form the vertices of the - convex hull. - """ - if competing_phases is None: - competing_phases = [None] * len(points) - xy_points = [(pt[x_element], pt[y_element]) for pt in points] - hull = ConvexHull(xy_points) - xy_hull = [xy_points[i] for i in hull.vertices] - - def _get_line_data(i1: int, i2: int) -> tuple: - cp1 = competing_phases[hull.vertices[i1]] - cp2 = competing_phases[hull.vertices[i2]] - shared_keys = cp1.keys() & cp2.keys() - shared_phase = {k: cp1[k] for k in shared_keys} - return xy_hull[i1], xy_hull[i2], shared_phase - - # return all pairs of points: - pt_and_phase = [ - _get_line_data(itr - 1, itr) for itr in range(1, len(hull.vertices)) - ] - pt_and_phase.append(_get_line_data(len(hull.vertices) - 1, 0)) - return pt_and_phase +__all__ = ["plot_chempot_2d"] diff --git a/pymatgen/analysis/defects/plotting/thermo.py b/pymatgen/analysis/defects/plotting/thermo.py new file mode 100644 index 00000000..ee372a8a --- /dev/null +++ b/pymatgen/analysis/defects/plotting/thermo.py @@ -0,0 +1,337 @@ +"""Plotting functions for defect thermo properties.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +import plotly.express as px +import plotly.graph_objects as go +from matplotlib import pyplot as plt +from matplotlib.patches import Polygon +from pymatgen.analysis.defects.thermo import group_formation_energy_diagrams +from pymatgen.util.string import latexify +from scipy.spatial import ConvexHull + +from .utils import get_line_style_and_color_sequence + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from pymatgen.analysis.defects.thermo import FormationEnergyDiagram + from pymatgen.core import Element + +# check if labellines is installed +try: + from labellines import labelLines +except ImportError: + + def labelLines(*args, **kwargs) -> None: # noqa: ARG001, ANN002 + """Dummy function if labellines is not installed.""" + + +PLOTLY_COLORS = px.colors.qualitative.T10 +PLOTLY_STYLES = ["solid", "dot", "dash", "longdash", "dashdot", "longdashdot"] + +if TYPE_CHECKING: + from collections.abc import Sequence + + from pymatgen.analysis.defects.thermo import FormationEnergyDiagram + + +def _plot_line( + pts: Sequence, + fig: go.Figure, + color: str, + style: str, + name: str, + meta: dict, +) -> None: + """Plot a sequence of x, y points as a line. + + Args: + pts: A sequence of x, y points. + fig: A plotly figure object. + color: The color of the line. + style: The style of the line. + name: The name of the line. + meta: A dictionary of metadata. + + Returns: + None, modifies the fig object in place. + """ + x_pos, y_pos = tuple(zip(*pts)) + trace_ = go.Scatter( + x=x_pos, + y=y_pos, + mode="lines+markers", + textposition="top right", + line=dict(color=color, dash=style), + hoverinfo="x", + meta=meta, + name=name, + ) + fig.add_trace(trace_) + + +def _label_lines(fig: go.Figure, name: str, x_anno: float, color: str) -> None: + """Label the lines in the figure. + + Args: + fig: A plotly figure object. + name: The unique identifier of the line. + x_anno: The x-coordinate of the annotation. + color: The color of the annotation. + + Returns: + None, modifies the fig object in place. + """ + for trace_ in fig.select_traces(selector={"name": name}): + x_pos, y_pos = trace_.x, trace_.y + y_anno = np.interp(x=x_anno, xp=x_pos, fp=y_pos) + fig.add_annotation( + x=x_anno, + y=y_anno, + text=trace_.name, + font=dict(color=color), + arrowcolor=color, + bgcolor="white", + bordercolor=color, + ) + + +def _label_slopes(fig: go.Figure) -> None: + """Label the slopes of the lines in the figure. + + Only labels lines that have the meta attribute 'formation_energy'. + + Args: + fig: A plotly figure object. + """ + for data_ in filter(lambda x: x.meta.get("formation_energy_plot", False), fig.data): + transitions_arr_ = np.array(tuple(zip(data_.x, data_.y))) + diff_arr = transitions_arr_[1:] - transitions_arr_[:-1] + slopes = tuple( + int(slope) for slope in np.round(diff_arr[:, 1] / diff_arr[:, 0]) + ) + pos = (transitions_arr_[:-1] + transitions_arr_[1:]) / 2.0 + x_pos, y_pos = tuple(zip(*pos)) + fig.add_trace( + go.Scatter( + x=x_pos, + y=y_pos, + text=slopes, + mode="text", + textposition="top center", + hoverinfo="skip", + textfont=dict(color=data_.line.color), + name=f"{data_.name}:slope", + showlegend=False, + ) + ) + + +def plot_formation_energy_diagrams( + feds: Sequence[FormationEnergyDiagram], chempot: dict | None = None +) -> go.Figure: + """Plot formation energy diagrams for a sequence of formation energy diagrams. + + Args: + feds: A sequence of formation energy diagrams. + chempot: A dictionary of chemical potentials. + + Returns: + A plotly figure object. + """ + fig = go.Figure() + plot_data = get_plot_data(feds, chempot) + + for name, data in plot_data.items(): + _plot_line( + pts=data["fed"].get_transitions(data["chempot"]), + fig=fig, + name=name, + color=data["color"], + style=data["style"], + meta={"formation_energy_plot": True}, + ) + _label_lines(fig=fig, name=name, x_anno=data["x_anno"], color=data["color"]) + + _label_slopes(fig) + + fig.update_layout( + title="Formation Energy Diagrams", + xaxis_title="Fermi Level (eV)", + yaxis_title="Formation Energy (eV)", + template="plotly_white", + font_family="Helvetica", + xaxis=dict(showgrid=False), + yaxis=dict(showgrid=False), + showlegend=False, + ) + return fig + + +def get_plot_data( + feds: Sequence[FormationEnergyDiagram], chempot: dict | None = None +) -> dict: + """Get the plot data for a sequence of formation energy diagrams. + + Args: + feds: A sequence of formation energy diagrams. + chempot: A dictionary of chemical potentials. + + Returns: + A dictionary of plot data. + - key: The unique identifier (just the unique name from group_formation_energy_diagrams). + - value: A dictionary with the following keys: + - fed: The formation energy diagram. + - style: The style of the line. + - color: The color of the line. + - x_anno: The x-coordinate of the annotation. + - chempot: The chemical potentials used to generate the transitions. + + """ + x_annos_ = np.linspace(0, feds[0].band_gap, len(feds) + 1, endpoint=True) + x_annos_ += (x_annos_[1] - x_annos_[0]) / 2 + x_annos_ = x_annos_[:-1] + + # Group formation energy diagrams by unique name + grouped_feds = list(group_formation_energy_diagrams(feds)) + plot_data = dict() + num_feds = len(grouped_feds) + for (name_, fed), color, x_anno in zip( + grouped_feds, + get_line_style_and_color_sequence(PLOTLY_COLORS, PLOTLY_STYLES), + x_annos_, + ): + if chempot is None: + cation_el_ = fed.chempot_diagram.elements[0] + chempot_ = fed.get_chempots(rich_element=cation_el_) + else: + chempot_ = chempot + plot_data[name_] = dict( + fed=fed, + style=color[0], + color=color[1], + x_anno=x_anno, + chempot=chempot_, + ) + + if len(plot_data) != num_feds: + msg = "Duplicate Name found in formation energy diagrams. " + raise ValueError( + msg, + "This should not happen since each unique defect should have a unique Name.", + ) + + return plot_data + + +def plot_chempot_2d( + fed: FormationEnergyDiagram, + x_element: Element, + y_element: Element, + ax: Axes | None = None, + min_mu: float = -5.0, + label_lines: bool = False, + x_vals: list[float] | None = None, + label_fontsize: int = 12, +) -> None: + """Plot the chemical potential diagram for two elements. + + Args: + fed: + The formation energy diagram. + x_element: + The element to use for the x-axis. + y_element: + The element to use for the y-axis. + ax: + The matplotlib axes to plot on. If None, a new figure will be created. + min_mu: + The minimum chemical potential to plot. + label_lines: + Whether to label the lines with the competing phases. Requires Labellines to be installed. + x_vals: + The x position of the line labels. If None, defaults will be used. + label_fontsize: + The fontsize for the line labels. + """ + PLOT_PADDING = 0.1 + ax = ax or plt.gca() + hull2d = _convex_hull_2d( + fed.chempot_limits, + x_element=x_element, + y_element=y_element, + competing_phases=fed.competing_phases, + ) + x_min = float("inf") + y_min = float("inf") + clip_path = [] + for p1, p2, phase in hull2d: + p_txt = ", ".join(map(latexify, phase.keys())) + ax.axline(p1, p2, label=p_txt, color="k") + ax.scatter(p1[0], p1[1], color="k") + x_m_ = p1[0] if p1[0] > min_mu else float("inf") + y_m_ = p1[1] if p1[1] > min_mu else float("inf") + x_min = min(x_min, x_m_) + y_min = min(y_min, y_m_) + clip_path.append(p1) + + patch = Polygon( + clip_path, + closed=True, + ) + ax.add_patch(patch) + + ax.set_xlabel(rf"$\Delta\mu_{{{x_element}}}$ (eV)") + ax.set_ylabel(rf"$\Delta\mu_{{{y_element}}}$ (eV)") + ax.set_xlim(x_min - PLOT_PADDING, 0 + PLOT_PADDING) + ax.set_ylim(y_min - PLOT_PADDING, 0 + PLOT_PADDING) + if label_lines: + labelLines(ax.get_lines(), align=False, xvals=x_vals, fontsize=label_fontsize) + + +def _convex_hull_2d( + points: list[dict], + x_element: Element, + y_element: Element, + competing_phases: list | None = None, +) -> list: + """Compute the convex hull of a set of points in 2D. + + Args: + points: + A list of dictionaries with keys "x" and "y" and values as floats. + x_element: + The element to use for the x-axis. + y_element: + The element to use for the y-axis. + tol: + The tolerance for determining if two points are the same in the 2D plane. + competing_phases: + A list of competing phases for each point. + + Returns: + A list of dictionaries with keys "x" and "y" that form the vertices of the + convex hull. + """ + if competing_phases is None: + competing_phases = [None] * len(points) + xy_points = [(pt[x_element], pt[y_element]) for pt in points] + hull = ConvexHull(xy_points) + xy_hull = [xy_points[i] for i in hull.vertices] + + def _get_line_data(i1: int, i2: int) -> tuple: + cp1 = competing_phases[hull.vertices[i1]] + cp2 = competing_phases[hull.vertices[i2]] + shared_keys = cp1.keys() & cp2.keys() + shared_phase = {k: cp1[k] for k in shared_keys} + return xy_hull[i1], xy_hull[i2], shared_phase + + # return all pairs of points: + pt_and_phase = [ + _get_line_data(itr - 1, itr) for itr in range(1, len(hull.vertices)) + ] + pt_and_phase.append(_get_line_data(len(hull.vertices) - 1, 0)) + return pt_and_phase diff --git a/pymatgen/analysis/defects/plotting/utils.py b/pymatgen/analysis/defects/plotting/utils.py new file mode 100644 index 00000000..a4eab023 --- /dev/null +++ b/pymatgen/analysis/defects/plotting/utils.py @@ -0,0 +1,28 @@ +"""Plotting utils.""" + +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Generator, Sequence + + +def get_line_style_and_color_sequence( + colors: Sequence, styles: Sequence +) -> Generator[tuple[str, str], None, None]: + """Get a generator for colors and styles. + + Create an iterator that will cycle through the colors and styles. + + Args: + colors: List of colors to use. + styles: List of styles to use. + + Returns: + Generator of (style, color) tuples + """ + for style in itertools.cycle(styles): + for color in itertools.cycle(colors): + yield style, color diff --git a/pymatgen/analysis/defects/thermo.py b/pymatgen/analysis/defects/thermo.py index 27c42cb8..98028a7f 100644 --- a/pymatgen/analysis/defects/thermo.py +++ b/pymatgen/analysis/defects/thermo.py @@ -11,6 +11,7 @@ import numpy as np from matplotlib import pyplot as plt +from monty.dev import deprecated from monty.json import MSONable from pymatgen.analysis.chempot_diagram import ChemicalPotentialDiagram from pymatgen.analysis.defects.core import Defect, NamedDefect @@ -212,6 +213,13 @@ def get_summary_dict(self) -> dict: res.update(corrections_d) return res + @property + def defect_chemsys(self) -> str: + """Get the chemical system of the defect.""" + return "-".join( + sorted({el.symbol for el in self.defect.defect_structure.elements}) + ) + @dataclass class FormationEnergyDiagram(MSONable): @@ -535,11 +543,21 @@ def competing_phases(self) -> list[dict[str, ComputedEntry]]: res.append(competing_phases) return res + @property + def bulk_formula(self) -> str: + """Get the bulk formula.""" + return self.defect_entries[0].defect.structure.composition.reduced_formula + @property def defect(self) -> Defect: """Get the defect that this FormationEnergyDiagram represents.""" return self.defect_entries[0].defect + @property + def defect_chemsys(self) -> str: + """Get the chemical system of the defect.""" + return self.defect_entries[0].defect_chemsys + def _get_lines(self, chempots: dict) -> list[tuple[float, float]]: """Get the lines for the formation energy diagram. @@ -852,21 +870,20 @@ def _get_hash_no_structure(entry: DefectEntry) -> tuple[str, str]: def group_formation_energy_diagrams( - feds: list[FormationEnergyDiagram], + feds: Sequence[FormationEnergyDiagram], sm: StructureMatcher = None, ) -> Generator[tuple[str | None, FormationEnergyDiagram], None, None]: """Group formation energy diagrams by their representation. - First by name then by structure. + First by name then by structure. Note, this function assumes that the defects + are for the same host structure. Args: feds: list of formation energy diagrams sm: StructureMatcher to use for grouping Returns: - If combine_diagrams is True, generator of (name, combined formation energy diagram) tuples. - If combine_diagrams is False, generator of (name, list of formation energy diagrams) tuples. - + Generator of (name, combined formation energy diagram) tuples. """ if sm is None: sm = StructureMatcher(comparator=ElementComparator()) @@ -1135,6 +1152,10 @@ def fermi_dirac(energy: float, temperature: float) -> float: return 1.0 / (1.0 + np.exp((energy) / (boltzman_eV_K * temperature))) +@deprecated( + message="Plotting functions will be moved to the the plotting module. " + "To integrate better with MP website, we will use the Plotly library for plotting." +) def plot_formation_energy_diagrams( formation_energy_diagrams: FormationEnergyDiagram | list[FormationEnergyDiagram] @@ -1332,7 +1353,7 @@ def plot_formation_energy_diagrams( if show_legend: lg = axis.get_legend() if lg: - handle, leg = lg.legendHandles, [txt._text for txt in lg.texts] + handle, leg = lg.legend_handles, [txt._text for txt in lg.texts] else: handle, leg = [], [] diff --git a/pymatgen/analysis/defects/utils.py b/pymatgen/analysis/defects/utils.py index e6ad31cf..2e41bbdc 100644 --- a/pymatgen/analysis/defects/utils.py +++ b/pymatgen/analysis/defects/utils.py @@ -30,7 +30,7 @@ from scipy.spatial.distance import squareform if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Generator, Sequence from pathlib import Path from numpy import typing as npt @@ -263,7 +263,7 @@ def get_zfile( raise FileNotFoundError(msg) -def generic_group_labels(list_in: list, comp: Callable = operator.eq) -> list[int]: +def generic_group_labels(list_in: Sequence, comp: Callable = operator.eq) -> list[int]: """Group a list of unsortable objects. Args: @@ -1053,7 +1053,7 @@ class CorrectionResult(MSONable): def _group_docs_by_structure( - docs: list, sm: StructureMatcher, get_structure: Callable + docs: Sequence, sm: StructureMatcher, get_structure: Callable ) -> Generator[list, None, None]: """Group docs by structure. @@ -1074,7 +1074,7 @@ def _group_docs_by_structure( def group_docs( - docs: list, + docs: Sequence, sm: StructureMatcher, get_structure: Callable, get_hash: Callable | None = None, diff --git a/tests/conftest.py b/tests/conftest.py index 68733e1b..330def5a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,9 @@ import pytest from monty.serialization import loadfn from pymatgen.analysis.defects.core import PeriodicSite, Substitution -from pymatgen.analysis.defects.thermo import DefectEntry -from pymatgen.core import Structure +from pymatgen.analysis.defects.thermo import DefectEntry, FormationEnergyDiagram +from pymatgen.analysis.phase_diagram import PhaseDiagram +from pymatgen.core import Element, Structure from pymatgen.core.periodic_table import Specie from pymatgen.io.vasp.outputs import WSWQ, Chgcar, Locpot, Procar, Vasprun @@ -47,7 +48,7 @@ def data_Mg_Ga(test_dir): "locpot": Locpot, }, ... - } + }. """ root_dir = test_dir / "Mg_Ga" data = defaultdict(dict) @@ -127,3 +128,42 @@ def v_N_GaN(test_dir): 2: Locpot.from_file(test_dir / "v_N_GaN/q=2/LOCPOT.gz"), }, } + + +@pytest.fixture(scope="session") +def basic_fed( + data_Mg_Ga, defect_entries_and_plot_data_Mg_Ga, stable_entries_Mg_Ga_N +): + bulk_vasprun = data_Mg_Ga["bulk_sc"]["vasprun"] + bulk_bs = bulk_vasprun.get_band_structure() + vbm = bulk_bs.get_vbm()["energy"] + bulk_entry = bulk_vasprun.get_computed_entry(inc_structure=False) + defect_entries, _ = defect_entries_and_plot_data_Mg_Ga + + def_ent_list = list(defect_entries.values()) + # test the constructor with materials project phase diagram + atomic_entries = list( + filter(lambda x: len(x.composition.elements) == 1, stable_entries_Mg_Ga_N) + ) + pd = PhaseDiagram(stable_entries_Mg_Ga_N) + # test the constructor with atomic entries + # this is the one we will use for the rest of the tests + fed = FormationEnergyDiagram.with_atomic_entries( + defect_entries=def_ent_list, + atomic_entries=atomic_entries, + vbm=vbm, + inc_inf_values=False, + phase_diagram=pd, + bulk_entry=bulk_entry, + ) + assert len(fed.chempot_limits) == 3 + + # dataframe conversion + df = fed.as_dataframe() + assert df.shape == (4, 5) + + # test that you can get the Ga-rich chempot + cp = fed.get_chempots(rich_element=Element("Ga")) + assert cp[Element("Ga")] == pytest.approx(0, abs=1e-2) + fed.band_gap = 2 + return fed diff --git a/tests/plotting/__init__.py b/tests/plotting/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/plotting/test_thermo.py b/tests/plotting/test_thermo.py new file mode 100644 index 00000000..19e5b54d --- /dev/null +++ b/tests/plotting/test_thermo.py @@ -0,0 +1,13 @@ +import pytest +from pymatgen.analysis.defects.plotting.thermo import plot_formation_energy_diagrams, plot_chempot_2d +from pymatgen.core import Element + +def test_fed_plot(basic_fed): + fig = plot_formation_energy_diagrams([basic_fed]) + assert {d_.name for d_ in fig.data} == {'Mg_Ga', 'Mg_Ga:slope'} + +def test_chempot_plot(basic_fed): + plot_chempot_2d(basic_fed, x_element=Element("Mg"), y_element=Element("Ga")) + + + diff --git a/tests/test_thermo.py b/tests/test_thermo.py index 2bce6fe9..47091fe4 100644 --- a/tests/test_thermo.py +++ b/tests/test_thermo.py @@ -150,6 +150,8 @@ def test_formation_energy_diagram_using_bulk_entry(formation_energy_diagram): pd_entries=fed.pd_entries, ) assert len(fed.chempot_limits) == 3 + assert fed.defect_chemsys == "Ga-Mg-N" + assert fed.bulk_formula == "GaN" def test_formation_energy_diagram_shape_fixed(formation_energy_diagram):