From b551581c5cb7c2ee7dea288a846e575969fa3235 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Thu, 19 Sep 2024 18:41:28 -0400 Subject: [PATCH] `element_pair_rdfs` plots radial distribution functions (RDFs) for element pairs in a structure (#203) * improve set_plotly_template auto-complete with Literal type * add element_pair_rdfs(structure) -> go.Figure in new pymatviz/rdf.py module * add tests/test_rdf.py * remove ase.Atoms conversion to avoid new pkg dep * show element_pair_rdfs examples in readme * should have used save_and_compress_svg --- assets/element-pair-rdfs-Na8Nb8O24.svg | 1 + assets/element-pair-rdfs-Si16O32.svg | 1 + examples/make_assets/rdf.py | 23 +++ pymatviz/__init__.py | 5 +- pymatviz/rdf.py | 186 +++++++++++++++++++++++++ pymatviz/templates.py | 5 +- readme.md | 11 ++ tests/test_rdf.py | 128 +++++++++++++++++ 8 files changed, 358 insertions(+), 2 deletions(-) create mode 100644 assets/element-pair-rdfs-Na8Nb8O24.svg create mode 100644 assets/element-pair-rdfs-Si16O32.svg create mode 100644 examples/make_assets/rdf.py create mode 100644 pymatviz/rdf.py create mode 100644 tests/test_rdf.py diff --git a/assets/element-pair-rdfs-Na8Nb8O24.svg b/assets/element-pair-rdfs-Na8Nb8O24.svg new file mode 100644 index 00000000..0cdfbc19 --- /dev/null +++ b/assets/element-pair-rdfs-Na8Nb8O24.svg @@ -0,0 +1 @@ +246810052468100510152468100510246810052468100510246810051015Pairwise RDFs - Na8 Nb8 O24r (Å)r (Å)r (Å)g(r)g(r)Na-NaNa-NbNa-ONb-NbNb-OO-O diff --git a/assets/element-pair-rdfs-Si16O32.svg b/assets/element-pair-rdfs-Si16O32.svg new file mode 100644 index 00000000..cc8c8f7c --- /dev/null +++ b/assets/element-pair-rdfs-Si16O32.svg @@ -0,0 +1 @@ +24681005101524681001020246810051015Pairwise RDFs - Si16 O32r (Å)r (Å)r (Å)g(r)O-OO-SiSi-Si diff --git a/examples/make_assets/rdf.py b/examples/make_assets/rdf.py new file mode 100644 index 00000000..ea954e95 --- /dev/null +++ b/examples/make_assets/rdf.py @@ -0,0 +1,23 @@ +from matminer.datasets import 
load_dataset + +import pymatviz as pmv +from pymatviz.enums import Key + + +pmv.set_plotly_template("pymatviz_white") + +df_phonons = load_dataset("matbench_phonons") + + +# get the 2 largest structures +df_phonons[Key.n_sites] = df_phonons[Key.structure].apply(len) + +# plot element-pair RDFs for each structure +for struct in df_phonons.nlargest(2, Key.n_sites)[Key.structure]: + fig = pmv.element_pair_rdfs(struct, n_bins=100, cutoff=10) + formula = struct.formula + fig.layout.title.update(text=f"Pairwise RDFs - {formula}", x=0.5, y=0.98) + fig.layout.margin = dict(l=40, r=0, t=50, b=0) + + fig.show() + pmv.io.save_and_compress_svg(fig, f"element-pair-rdfs-{formula.replace(' ', '')}") diff --git a/pymatviz/__init__.py b/pymatviz/__init__.py index 1f259f0b..f493a57e 100644 --- a/pymatviz/__init__.py +++ b/pymatviz/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +import builtins from importlib.metadata import PackageNotFoundError, version import matplotlib.pyplot as plt @@ -30,6 +31,7 @@ powerups, process_data, ptable, + rdf, relevance, sankey, scatter, @@ -56,6 +58,7 @@ ptable_lines, ptable_scatters, ) +from pymatviz.rdf import element_pair_rdfs from pymatviz.relevance import precision_recall_curve, roc_curve from pymatviz.sankey import sankey_from_2_df_cols from pymatviz.scatter import ( @@ -94,7 +97,7 @@ pass # package not installed -IS_IPYTHON = hasattr(__builtins__, "__IPYTHON__") +IS_IPYTHON = hasattr(builtins, "__IPYTHON__") # define a sensible order for crystal systems across plots crystal_sys_order = ( diff --git a/pymatviz/rdf.py b/pymatviz/rdf.py new file mode 100644 index 00000000..760629c5 --- /dev/null +++ b/pymatviz/rdf.py @@ -0,0 +1,186 @@ +"""This module calculates and plots pairwise radial distribution functions (RDFs) for +pymatgen structures using plotly. + +The main function, element_pair_rdfs, generates a plotly figure with facets for each +pair of elements in the given structure. 
It supports customization of cutoff distance, +bin size, specific element pairs to plot, and an optional reference line. + +Example usage: + structure = Structure(...) # Create or load a pymatgen Structure + fig = element_pair_rdfs(structure, bin_size=0.1) + fig.show() +""" + +from typing import Any + +import numpy as np +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from pymatgen.core import Structure +from scipy.signal import find_peaks + + +def calculate_rdf( + structure: Structure, + center_species: str, + neighbor_species: str, + cutoff: float, + n_bins: int, +) -> tuple[np.ndarray, np.ndarray]: + """Calculate the radial distribution function (RDF) for a given pair of species. + + The RDF is normalized by the number of pairs and the shell volume density, which + makes the RDF approach 1 for large separations in a homogeneous system. + + Args: + structure (Structure): A pymatgen Structure object. + center_species (str): Symbol of the central species. + neighbor_species (str): Symbol of the neighbor species. + cutoff (float): Maximum distance for RDF calculation. + n_bins (int): Number of bins for RDF calculation. + + Returns: + tuple[np.ndarray, np.ndarray]: Arrays of (radii, g(r)) values. 
+ """ + bin_size = cutoff / n_bins + radii = np.linspace(0, cutoff, n_bins + 1)[1:] + rdf = np.zeros(n_bins) + + center_indices = [ + i for i, site in enumerate(structure) if site.specie.symbol == center_species + ] + neighbor_indices = [ + i for i, site in enumerate(structure) if site.specie.symbol == neighbor_species + ] + + for center_idx in center_indices: + for neighbor_idx in neighbor_indices: + if center_idx != neighbor_idx: + distance = structure.get_distance(center_idx, neighbor_idx) + if distance < cutoff: + rdf[int(distance / bin_size)] += 1 + + # Normalize RDF by the number of center-neighbor pairs and shell volumes + rdf = rdf / (len(center_indices) * len(neighbor_indices)) + shell_volumes = 4 * np.pi * radii**2 * bin_size + rdf = rdf / (shell_volumes / structure.volume) + + return radii, rdf + + +def find_last_significant_peak( + radii: np.ndarray, rdf: np.ndarray, prominence: float = 0.1 +) -> float: + """Find the position of the last significant peak in the RDF.""" + peaks, properties = find_peaks(rdf, prominence=prominence, distance=5) + if peaks.size > 0: + # Sort peaks by prominence and select the last significant one + sorted_peaks = peaks[np.argsort(properties["prominences"])] + return radii[sorted_peaks[-1]] + return radii[-1] + + +def element_pair_rdfs( + structure: Structure, + cutoff: float = 15, + n_bins: int = 75, + bin_size: float | None = None, + element_pairs: list[tuple[str, str]] | None = None, + reference_line: dict[str, Any] | None = None, +) -> go.Figure: + """Generate a plotly figure of pairwise radial distribution functions (RDFs) for + all (or a subset of) element pairs in a structure. + + The RDF is the probability of finding a neighbor at a distance r from a central + atom. Basically a histogram of pair-wise particle distances. + + Args: + structure (Structure): pymatgen Structure. + cutoff (float, optional): Maximum distance for RDF calculation. Default is 15 Å. + n_bins (int, optional): Number of bins for RDF calculation. 
Default is 75. + bin_size (float, optional): Size of bins for RDF calculation. If specified, it + overrides n_bins. Default is None. + element_pairs (list[tuple[str, str]], optional): Element pairs to plot. + If None, all pairs are plotted. + reference_line (dict, optional): Keywords for reference line at g(r)=1 drawn + with Figure.add_hline(). If None (default), no reference line is drawn. + + Returns: + go.Figure: A plotly figure with facets for each pairwise RDF. + + Raises: + ValueError: If the structure contains no sites, if invalid element pairs are + provided, or if both n_bins and bin_size are specified. + """ + if not structure.sites: + raise ValueError("input structure contains no sites") + + if n_bins != 75 and bin_size is not None: + raise ValueError( + f"Cannot specify both {n_bins=} and {bin_size=}. Pick one or the other." + ) + + uniq_elements = sorted({site.specie.symbol for site in structure}) + element_pairs = element_pairs or [ + (e1, e2) for e1 in uniq_elements for e2 in uniq_elements if e1 <= e2 + ] + element_pairs = sorted(element_pairs) + + if extra_elems := {e1 for e1, _e2 in element_pairs} - set(uniq_elements): + raise ValueError( + f"Elements {extra_elems} in element_pairs are not present in the structure" + ) + + # Calculate pairwise RDFs + if bin_size is not None: + n_bins = int(cutoff / bin_size) + elem_pair_rdfs = { + pair: calculate_rdf(structure, *pair, cutoff, n_bins) for pair in element_pairs + } + + # Determine subplot layout + n_pairs = len(element_pairs) + n_cols = min(3, n_pairs) + n_rows = (n_pairs + n_cols - 1) // n_cols + + # Create the plotly figure with facets + fig = make_subplots( + rows=n_rows, + cols=n_cols, + subplot_titles=[f"{e1}-{e2}" for e1, e2 in element_pairs], + vertical_spacing=0.25 / n_rows, + horizontal_spacing=0.15 / n_cols, + ) + + # Add RDF traces to the figure + for idx, (pair, (radii, rdf)) in enumerate(elem_pair_rdfs.items()): + row, col = divmod(idx, n_cols) + row += 1 + col += 1 + + fig.add_scatter( 
+ x=radii, + y=rdf, + mode="lines", + name=f"{pair[0]}-{pair[1]}", + line=dict(color="royalblue"), + showlegend=False, + row=row, + col=col, + hovertemplate="r = %{x:.2f} Å
g(r) = %{y:.2f}", + ) + + # if one of the last n_col subplots, add x-axis label + if idx >= n_pairs - n_cols: + fig.update_xaxes(title_text="r (Å)", row=row, col=col) + + # Add reference line if specified + if reference_line is not None: + defaults = dict(line_dash="dash", line_color="red") + fig.add_hline(y=1, row=row, col=col, **defaults | reference_line) + + # set subplot height/width and x/y axis labels + fig.update_layout(height=200 * n_rows, width=350 * n_cols) + fig.update_yaxes(title=dict(text="g(r)", standoff=0.1), col=1) + + return fig diff --git a/pymatviz/templates.py b/pymatviz/templates.py index bbfac042..625701da 100644 --- a/pymatviz/templates.py +++ b/pymatviz/templates.py @@ -3,6 +3,7 @@ from __future__ import annotations from importlib.metadata import PackageNotFoundError, version +from typing import Literal import matplotlib.pyplot as plt import plotly.express as px @@ -62,7 +63,9 @@ ) -def set_plotly_template(template: str | go.layout.Template) -> None: +def set_plotly_template( + template: Literal["pymatviz_white", "pymatviz_dark"] | str | go.layout.Template, # noqa: PYI051 +) -> None: """Set the default plotly express and graph objects template. Args: diff --git a/readme.md b/readme.md index 4f652cb7..52575d55 100644 --- a/readme.md +++ b/readme.md @@ -168,6 +168,17 @@ See [`pymatviz/xrd.py`](pymatviz/xrd.py). [xrd-pattern]: https://github.com/janosh/pymatviz/raw/main/assets/xrd-pattern.svg [xrd-pattern-multiple]: https://github.com/janosh/pymatviz/raw/main/assets/xrd-pattern-multiple.svg +## Radial Distribution Functions + +See [`pymatviz/rdf.py`](pymatviz/rdf.py). 
+ +| [`element_pair_rdfs(structure)`](pymatviz/rdf.py) | [`element_pair_rdfs(structure)`](pymatviz/rdf.py) | +| :-----------------------------------------------: | :-----------------------------------------------: | +| ![element-pair-rdfs-Si16O32] | ![element-pair-rdfs-Na8Nb8O24] | + +[element-pair-rdfs-Si16O32]: https://github.com/janosh/pymatviz/raw/main/assets/element-pair-rdfs-Si16O32.svg +[element-pair-rdfs-Na8Nb8O24]: https://github.com/janosh/pymatviz/raw/main/assets/element-pair-rdfs-Na8Nb8O24.svg + ## Uncertainty See [`pymatviz/uncertainty.py`](pymatviz/uncertainty.py). diff --git a/tests/test_rdf.py b/tests/test_rdf.py new file mode 100644 index 00000000..56e36d57 --- /dev/null +++ b/tests/test_rdf.py @@ -0,0 +1,128 @@ +import numpy as np +import plotly.graph_objects as go +import pytest +from pymatgen.core import Lattice, Structure + +from pymatviz.rdf import calculate_rdf, element_pair_rdfs + + +def test_element_pair_rdfs_basic(structures: list[Structure]) -> None: + for structure in structures: + fig = element_pair_rdfs(structure) + assert isinstance(fig, go.Figure) + assert fig.layout.title.text is None + assert fig.layout.showlegend is None + assert fig.layout.yaxis.title.text == "g(r)" + + +def test_element_pair_rdfs_empty_structure() -> None: + empty_structure = Structure(Lattice.cubic(1), [], []) + with pytest.raises(ValueError, match="input structure contains no sites"): + element_pair_rdfs(empty_structure) + + +def test_element_pair_rdfs_invalid_element_pairs(structures: list[Structure]) -> None: + with pytest.raises( + ValueError, + match="Elements .* in element_pairs are not present in the structure", + ): + element_pair_rdfs( + structures[0], element_pairs=[("Zn", "Zn")] + ) # Assuming Zn is not in the structure + + +@pytest.mark.parametrize( + ("param", "values"), + [("cutoff", (5, 10, 15)), ("bin_size", (0.05, 0.1, 0.2))], +) +def test_element_pair_rdfs_cutoff_and_bin_size( + structures: list[Structure], param: str, values: tuple[float, ...] 
+) -> None: + structure = structures[0] + for value in values: + fig = element_pair_rdfs(structure, **{param: value}) # type: ignore[arg-type] + + # Check that we have the correct number of traces (one for each element pair) + n_elements = len({site.specie.symbol for site in structure}) + expected_traces = n_elements * (n_elements + 1) // 2 + assert ( + len(fig.data) == expected_traces + ), f"Expected {expected_traces} traces, got {len(fig.data)}" + + for trace in fig.data: + if param == "cutoff": + # Check that the x-axis data doesn't exceed the cutoff + assert np.all( + trace.x <= value + ), f"X-axis data exceeds cutoff of {value}" + # Check that the maximum x value is close to the cutoff + assert max(trace.x) == pytest.approx( + value + ), f"Maximum x value {max(trace.x)} not close to cutoff {value}" + elif param == "bin_size": + # Check that the number of bins is approximately correct + default_cutoff = 15 # must match the default cutoff of element_pair_rdfs (15 Å) + expected_bins = int(np.ceil(default_cutoff / value)) + assert ( + abs(len(trace.x) - expected_bins) <= 1 + ), f"Expected around {expected_bins} bins, got {len(trace.x)}" + + +def test_element_pair_rdfs_element_pairs(structures: list[Structure]) -> None: + element_pairs = [("Si", "Si")] + fig = element_pair_rdfs(structures[0], element_pairs=element_pairs) + assert len(fig.data) == len(element_pairs) + assert fig.data[0].name == "Si-Si" + + +def test_element_pair_rdfs_subplot_layout(structures: list[Structure]) -> None: + for structure in structures: + fig = element_pair_rdfs(structure) + n_elements = len({site.specie.symbol for site in structure}) + expected_pairs = n_elements * (n_elements + 1) // 2 + assert len(fig.data) == expected_pairs + assert all(isinstance(trace, go.Scatter) for trace in fig.data) + + +def test_calculate_rdf(structures: list[Structure]) -> None: + for structure in structures: + elements = list({site.specie.symbol for site in structure}) + for e1 in elements: + for e2 in elements: + radii, rdf = 
calculate_rdf(structure, e1, e2, 10.0, 100) + assert isinstance(radii, np.ndarray) + assert isinstance(rdf, np.ndarray) + assert len(radii) == len(rdf) + assert np.all(rdf >= 0) + + +@pytest.mark.parametrize( + "element_pairs", [[("Si", "Si")], [("Si", "Ru"), ("Pr", "Pr")], None] +) +def test_element_pair_rdfs_custom_element_pairs( + structures: list[Structure], element_pairs: list[tuple[str, str]] | None +) -> None: + structure = structures[1] # Use the structure with Si, Ru, and Pr + fig = element_pair_rdfs(structure, element_pairs=element_pairs) + expected_pairs = sorted( + element_pairs + if element_pairs + else [ + (e1, e2) + for e1 in structure.symbol_set + for e2 in structure.symbol_set + if e1 <= e2 + ] + ) + assert len(fig.data) == len(expected_pairs) + for trace, pair in zip(fig.data, expected_pairs, strict=True): + assert trace.name == f"{pair[0]}-{pair[1]}" + + +def test_element_pair_rdfs_consistency(structures: list[Structure]) -> None: + for structure in structures: + fig1 = element_pair_rdfs(structure, cutoff=5.0, bin_size=0.1) + fig2 = element_pair_rdfs(structure, cutoff=5.0, bin_size=0.1) + for trace1, trace2 in zip(fig1.data, fig2.data, strict=True): + assert np.allclose(trace1.x, trace2.x) + assert np.allclose(trace1.y, trace2.y)