Skip to content

Commit

Permalink
element_pair_rdfs plots radial distribution functions (RDFs) for el…
Browse files Browse the repository at this point in the history
…ement pairs in a structure (#203)

* improve set_plotly_template auto-complete with Literal type

* add element_pair_rdfs(structure) -> go.Figure in new pymatviz/rdf.py module

* add tests/test_rdf.py

* remove ase.Atoms conversion to avoid new pkg dep

* show element_pair_rdfs examples in readme

* should have used save_and_compress_svg
  • Loading branch information
janosh authored Sep 19, 2024
1 parent 24261ca commit b551581
Show file tree
Hide file tree
Showing 8 changed files with 358 additions and 2 deletions.
1 change: 1 addition & 0 deletions assets/element-pair-rdfs-Na8Nb8O24.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions assets/element-pair-rdfs-Si16O32.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
23 changes: 23 additions & 0 deletions examples/make_assets/rdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from matminer.datasets import load_dataset

import pymatviz as pmv
from pymatviz.enums import Key


pmv.set_plotly_template("pymatviz_white")

df_phonons = load_dataset("matbench_phonons")


# record the number of sites per structure so structures can be ranked by size
df_phonons[Key.n_sites] = [len(struct) for struct in df_phonons[Key.structure]]

# plot element-pair RDFs for the 2 largest structures
largest_structs = df_phonons.nlargest(2, Key.n_sites)[Key.structure]
for struct in largest_structs:
    formula = struct.formula
    fig = pmv.element_pair_rdfs(struct, n_bins=100, cutoff=10)
    fig.layout.title.update(text=f"Pairwise RDFs - {formula}", x=0.5, y=0.98)
    fig.layout.margin = dict(l=40, r=0, t=50, b=0)

    fig.show()
    # file name is the formula with its spaces removed
    svg_name = f"element-pair-rdfs-{formula.replace(' ', '')}"
    pmv.io.save_and_compress_svg(fig, svg_name)
5 changes: 4 additions & 1 deletion pymatviz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import builtins
from importlib.metadata import PackageNotFoundError, version

import matplotlib.pyplot as plt
Expand All @@ -30,6 +31,7 @@
powerups,
process_data,
ptable,
rdf,
relevance,
sankey,
scatter,
Expand All @@ -56,6 +58,7 @@
ptable_lines,
ptable_scatters,
)
from pymatviz.rdf import element_pair_rdfs
from pymatviz.relevance import precision_recall_curve, roc_curve
from pymatviz.sankey import sankey_from_2_df_cols
from pymatviz.scatter import (
Expand Down Expand Up @@ -94,7 +97,7 @@
pass # package not installed


IS_IPYTHON = hasattr(__builtins__, "__IPYTHON__")
IS_IPYTHON = hasattr(builtins, "__IPYTHON__")

# define a sensible order for crystal systems across plots
crystal_sys_order = (
Expand Down
186 changes: 186 additions & 0 deletions pymatviz/rdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""This module calculates and plots pairwise radial distribution functions (RDFs) for
pymatgen structures using plotly.
The main function, element_pair_rdfs, generates a plotly figure with facets for each
pair of elements in the given structure. It supports customization of the cutoff
distance, bin size, specific element pairs to plot, and an optional reference line.

Example usage:
    structure = Structure(...)  # Create or load a pymatgen Structure
    fig = element_pair_rdfs(structure, bin_size=0.1)
    fig.show()
"""

from typing import Any

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pymatgen.core import Structure
from scipy.signal import find_peaks


def calculate_rdf(
    structure: Structure,
    center_species: str,
    neighbor_species: str,
    cutoff: float,
    n_bins: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Calculate the radial distribution function (RDF) for a given pair of species.

    Neighbors are collected with Structure.get_all_neighbors() so that periodic
    images (including images of the center site itself) within the cutoff are
    counted. The RDF is normalized by the number of center-neighbor pairs and the
    shell volume density, which makes the RDF approach 1 for large separations in
    a homogeneous system.

    Args:
        structure (Structure): A pymatgen Structure object.
        center_species (str): Symbol of the central species.
        neighbor_species (str): Symbol of the neighbor species.
        cutoff (float): Maximum distance for RDF calculation.
        n_bins (int): Number of bins for RDF calculation.

    Returns:
        tuple[np.ndarray, np.ndarray]: Arrays of (radii, g(r)) values, each of
            length n_bins. radii holds the outer edge of every bin. g(r) is all
            zeros if either species is absent from the structure.

    Raises:
        ValueError: If cutoff is not positive or n_bins is less than 1.
    """
    if cutoff <= 0 or n_bins < 1:
        raise ValueError(f"{cutoff=} must be positive and {n_bins=} at least 1")

    bin_size = cutoff / n_bins
    radii = np.linspace(0, cutoff, n_bins + 1)[1:]

    center_sites = [site for site in structure if site.specie.symbol == center_species]
    n_neighbor_sites = sum(
        site.specie.symbol == neighbor_species for site in structure
    )
    if not center_sites or n_neighbor_sites == 0:
        # no pairs to count; also avoids a 0/0 in the normalization below
        return radii, np.zeros(n_bins)

    # get_distance() would only return the shortest-image distance of each site
    # pair, missing all further periodic images inside the cutoff sphere
    distances = [
        neighbor.nn_distance
        for neighbors in structure.get_all_neighbors(cutoff, sites=center_sites)
        for neighbor in neighbors
        if neighbor.specie.symbol == neighbor_species
        and neighbor.nn_distance < cutoff
    ]
    # histogram over [0, cutoff) can't produce an out-of-range bin index, unlike
    # int(distance / bin_size) which may round up to n_bins
    counts, _bin_edges = np.histogram(distances, bins=n_bins, range=(0, cutoff))

    # Normalize RDF by the number of center-neighbor pairs and shell volumes
    rdf = counts / (len(center_sites) * n_neighbor_sites)
    shell_volumes = 4 * np.pi * radii**2 * bin_size
    rdf = rdf / (shell_volumes / structure.volume)

    return radii, rdf


def find_last_significant_peak(
    radii: np.ndarray, rdf: np.ndarray, prominence: float = 0.1
) -> float:
    """Find the position of the last significant peak in the RDF.

    A peak counts as significant if scipy.signal.find_peaks reports it with at
    least the given prominence (peaks closer than 5 samples are thinned by
    find_peaks).

    Args:
        radii (np.ndarray): Radii at which the RDF was evaluated.
        rdf (np.ndarray): g(r) values, same length as radii.
        prominence (float, optional): Minimum prominence for a peak to be
            considered significant. Default is 0.1.

    Returns:
        float: Radius of the significant peak at the largest r, or the last
            radius if no significant peak was found.
    """
    peaks, _properties = find_peaks(rdf, prominence=prominence, distance=5)
    if peaks.size > 0:
        # find_peaks returns indices in ascending order, so the final entry is the
        # peak at the largest radius (not the most prominent one)
        return float(radii[peaks[-1]])
    return float(radii[-1])


def element_pair_rdfs(
    structure: Structure,
    cutoff: float = 15,
    n_bins: int = 75,
    bin_size: float | None = None,
    element_pairs: list[tuple[str, str]] | None = None,
    reference_line: dict[str, Any] | None = None,
) -> go.Figure:
    """Generate a plotly figure of pairwise radial distribution functions (RDFs) for
    all (or a subset of) element pairs in a structure.

    The RDF is the probability of finding a neighbor at a distance r from a central
    atom. Basically a histogram of pair-wise particle distances.

    Args:
        structure (Structure): pymatgen Structure.
        cutoff (float, optional): Maximum distance for RDF calculation. Default is 15 Å.
        n_bins (int, optional): Number of bins for RDF calculation. Default is 75.
        bin_size (float, optional): Size of bins for RDF calculation. If specified, it
            overrides n_bins. Default is None.
        element_pairs (list[tuple[str, str]], optional): Element pairs to plot.
            Duplicate pairs are plotted once. If None or empty, all pairs are plotted.
        reference_line (dict, optional): Keywords for reference line at g(r)=1 drawn
            with Figure.add_hline(). If None (default), no reference line is drawn.

    Returns:
        go.Figure: A plotly figure with facets for each pairwise RDF.

    Raises:
        ValueError: If the structure contains no sites, if invalid element pairs are
            provided, if both n_bins and bin_size are specified, or if bin_size/n_bins
            do not yield at least one bin.
    """
    if not structure.sites:
        raise ValueError("input structure contains no sites")

    # n_bins=75 is the default, so any other value means the caller set it explicitly
    if n_bins != 75 and bin_size is not None:
        raise ValueError(
            f"Cannot specify both {n_bins=} and {bin_size=}. Pick one or the other."
        )

    uniq_elements = sorted({site.specie.symbol for site in structure})
    if element_pairs:
        # dedupe so subplot titles/count stay in sync with the computed RDFs
        element_pairs = sorted({(e1, e2) for e1, e2 in element_pairs})
    else:
        element_pairs = [
            (e1, e2) for e1 in uniq_elements for e2 in uniq_elements if e1 <= e2
        ]

    # check both members of every pair, not just the first
    requested_elems = {elem for pair in element_pairs for elem in pair}
    if extra_elems := requested_elems - set(uniq_elements):
        raise ValueError(
            f"Elements {extra_elems} in element_pairs are not present in the structure"
        )

    if bin_size is not None:
        if bin_size <= 0:
            raise ValueError(f"{bin_size=} must be positive")
        n_bins = int(cutoff / bin_size)
    if n_bins < 1:
        raise ValueError(
            f"{n_bins=} must be at least 1, check {cutoff=} and {bin_size=}"
        )

    # Calculate pairwise RDFs
    elem_pair_rdfs = {
        pair: calculate_rdf(structure, *pair, cutoff, n_bins) for pair in element_pairs
    }

    # Determine subplot layout (at most 3 columns)
    n_pairs = len(element_pairs)
    n_cols = min(3, n_pairs)
    n_rows = (n_pairs + n_cols - 1) // n_cols

    # Create the plotly figure with facets
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        subplot_titles=[f"{e1}-{e2}" for e1, e2 in element_pairs],
        vertical_spacing=0.25 / n_rows,
        horizontal_spacing=0.15 / n_cols,
    )

    # user-supplied keywords take precedence over the default line style
    hline_kwargs = None
    if reference_line is not None:
        hline_kwargs = dict(line_dash="dash", line_color="red") | reference_line

    # Add RDF traces to the figure
    for idx, pair in enumerate(element_pairs):
        radii, rdf = elem_pair_rdfs[pair]
        row, col = divmod(idx, n_cols)
        row += 1  # plotly subplot indices are 1-based
        col += 1

        fig.add_scatter(
            x=radii,
            y=rdf,
            mode="lines",
            name=f"{pair[0]}-{pair[1]}",
            line=dict(color="royalblue"),
            showlegend=False,
            row=row,
            col=col,
            hovertemplate="r = %{x:.2f} Å<br>g(r) = %{y:.2f}<extra></extra>",
        )

        # if one of the last n_col subplots, add x-axis label
        if idx >= n_pairs - n_cols:
            fig.update_xaxes(title_text="r (Å)", row=row, col=col)

        # Add reference line if specified
        if hline_kwargs is not None:
            fig.add_hline(y=1, row=row, col=col, **hline_kwargs)

    # set subplot height/width and x/y axis labels
    fig.update_layout(height=200 * n_rows, width=350 * n_cols)
    fig.update_yaxes(title=dict(text="g(r)", standoff=0.1), col=1)

    return fig
5 changes: 4 additions & 1 deletion pymatviz/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version
from typing import Literal

import matplotlib.pyplot as plt
import plotly.express as px
Expand Down Expand Up @@ -62,7 +63,9 @@
)


def set_plotly_template(template: str | go.layout.Template) -> None:
def set_plotly_template(
template: Literal["pymatviz_white", "pymatviz_dark"] | str | go.layout.Template, # noqa: PYI051
) -> None:
"""Set the default plotly express and graph objects template.
Args:
Expand Down
11 changes: 11 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ See [`pymatviz/xrd.py`](pymatviz/xrd.py).
[xrd-pattern]: https://github.com/janosh/pymatviz/raw/main/assets/xrd-pattern.svg
[xrd-pattern-multiple]: https://github.com/janosh/pymatviz/raw/main/assets/xrd-pattern-multiple.svg

## Radial Distribution Functions

See [`pymatviz/rdf.py`](pymatviz/rdf.py).

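| [`element_pair_rdfs(struct, n_bins=100, cutoff=10)`](pymatviz/rdf.py) | [`element_pair_rdfs(struct, n_bins=100, cutoff=10)`](pymatviz/rdf.py) |
| :-------------------------------------------------------------------: | :-------------------------------------------------------------------: |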
| ![element-pair-rdfs-Si16O32] | ![element-pair-rdfs-Na8Nb8O24] |

[element-pair-rdfs-Si16O32]: https://github.com/janosh/pymatviz/raw/main/assets/element-pair-rdfs-Si16O32.svg
[element-pair-rdfs-Na8Nb8O24]: https://github.com/janosh/pymatviz/raw/main/assets/element-pair-rdfs-Na8Nb8O24.svg

## Uncertainty

See [`pymatviz/uncertainty.py`](pymatviz/uncertainty.py).
Expand Down
Loading

0 comments on commit b551581

Please sign in to comment.