Skip to content

Commit

Permalink
Merge branch 'main' into brillouin-zone-3d
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Nov 29, 2024
2 parents a526a72 + d6f6bef commit 15c2642
Show file tree
Hide file tree
Showing 17 changed files with 283 additions and 65 deletions.
59 changes: 58 additions & 1 deletion assets/scripts/ptable_plotly/ptable_heatmap_splits_plotly.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# %%
import itertools
from collections.abc import Callable, Sequence

import numpy as np
from pymatgen.core.periodic_table import Element
from pymatgen.core import Element

import pymatviz as pmv
import pymatviz.colors as pmv_colors
from pymatviz.typing import RgbColorType


np_rng = np.random.default_rng(seed=0)
Expand Down Expand Up @@ -34,3 +37,57 @@
fig.show()
if orientation == "diagonal":
pmv.io.save_and_compress_svg(fig, f"ptable-heatmap-splits-plotly-{n_splits}")


# %% Visualize multiple element color schemes on a split periodic table heatmap
def make_color_scale(
color_schemes: Sequence[dict[str, RgbColorType]],
) -> Callable[[str, float, int], str]:
"""Return element colors in different palettes based on split index."""

def elem_color_scale(element: str, _val: float, split_idx: int) -> str:
color = color_schemes[split_idx].get(element)
if color is None:
raise ValueError(f"no color for {element=} in {split_idx=}")
return f"rgb{color}"

return elem_color_scale


palettes_3 = (
pmv_colors.ELEM_COLORS_ALLOY,
pmv_colors.ELEM_COLORS_JMOL,
pmv_colors.ELEM_COLORS_VESTA,
)

fig = pmv.ptable_heatmap_splits_plotly(
# Use dummy values for all elements
{str(elem): list(range(len(palettes_3))) for elem in Element},
orientation="diagonal", # could also use "grid"
colorscale=make_color_scale(palettes_3),
hover_data=dict.fromkeys(
map(str, Element), "top left: JMOL<br>top right: VESTA, bottom: ALLOY"
),
)
title = (
"<b>Element color schemes</b><br>top left: JMOL, top right: VESTA, bottom: ALLOY"
)
fig.layout.title.update(text=title, x=0.4, y=0.8)
fig.show()
pmv.io.save_and_compress_svg(fig, "ptable-heatmap-splits-plotly-3-color-schemes")


# %% Visualize multiple element color schemes on a split periodic table heatmap
palettes_2 = (pmv_colors.ELEM_COLORS_ALLOY, pmv_colors.ELEM_COLORS_VESTA)

fig = pmv.ptable_heatmap_splits_plotly(
# Use dummy values for all elements
{str(elem): list(range(len(palettes_2))) for elem in Element},
orientation="vertical",
colorscale=make_color_scale(palettes_2),
hover_data=dict.fromkeys(map(str, Element), "left: VESTA<br>right: ALLOY"),
)
title = "<b>Element color schemes</b><br>left: VESTA, right: ALLOY"
fig.layout.title.update(text=title, x=0.4, y=0.8)
fig.show()
pmv.io.save_and_compress_svg(fig, "ptable-heatmap-splits-plotly-2-color-schemes")
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
72 changes: 63 additions & 9 deletions pymatviz/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if TYPE_CHECKING:
from typing import Final

from matplotlib.typing import ColorType
from pymatviz.typing import Rgb256ColorType, RgbColorType


# Element type based colors
Expand All @@ -32,7 +32,7 @@

# The following element-based colors are copied from elementari:
# https://github.com/janosh/elementari/blob/85a044cd/src/lib/colors.ts#L20-L242
ELEM_COLORS_JMOL: dict[str, ColorType] = {
ELEM_COLORS_JMOL_256: dict[str, Rgb256ColorType] = {
"H": (255, 255, 255),
"He": (217, 255, 255),
"Li": (204, 128, 255),
Expand Down Expand Up @@ -145,12 +145,12 @@
}

# Scale color value to [0, 1] for matplotlib
ELEM_COLORS_JMOL = {
elem: tuple(color / 255 for color in colors)
for elem, colors in ELEM_COLORS_JMOL.items()
ELEM_COLORS_JMOL: dict[str, RgbColorType] = {
elem: (r / 255, g / 255, b / 255)
for elem, (r, g, b) in ELEM_COLORS_JMOL_256.items()
}

ELEM_COLORS_VESTA: dict[str, ColorType] = {
ELEM_COLORS_VESTA_256: dict[str, Rgb256ColorType] = {
"Ac": (112, 171, 250),
"Ag": (192, 192, 192),
"Al": (129, 178, 214),
Expand Down Expand Up @@ -262,7 +262,61 @@
"Zr": (0, 255, 0),
}

ELEM_COLORS_VESTA = {
elem: tuple(color / 255 for color in colors)
for elem, colors in ELEM_COLORS_VESTA.items()
ELEM_COLORS_VESTA: dict[str, RgbColorType] = {
elem: (r / 255, g / 255, b / 255)
for elem, (r, g, b) in ELEM_COLORS_VESTA_256.items()
}


# High-contrast color scheme optimized for metal alloys while preserving some familiar
# colors. Merge with ELEM_COLORS_VESTA to get a complete color scheme while only
# overriding metal colors.
ELEM_COLORS_ALLOY_256: dict[str, Rgb256ColorType] = ELEM_COLORS_VESTA_256 | {
# Alkali metals - bright purples
"Li": (0, 53, 0), # Bright purple
"Na": (0, 41, 255), # Deep purple
"K": (0, 255, 0), # Royal purple
"Rb": (0, 255, 255), # Dark purple
"Cs": (255, 0, 0), # Deep violet
# Alkaline earth metals - yellows/oranges
"Be": (255, 0, 255), # Golden yellow
"Mg": (255, 255, 0), # Dark orange
"Ca": (255, 255, 255), # Bright orange
"Sr": (38, 154, 0), # Red-orange
"Ba": (0, 150, 255), # Pure red
# Transition metals - maximizing contrast
"Sc": (207, 26, 128), # Light gray (from JMOL)
"Ti": (216, 219, 127), # Purple (changed from blue for more contrast with Zr)
"V": (255, 150, 0), # Pink
"Cr": (197, 163, 255), # Bright green
"Mn": (0, 46, 133), # Magenta
"Fe": (0, 151, 134), # Bright orange (changed from JMOL)
"Co": (0, 255, 121), # Deep blue
"Ni": (99, 0, 62), # Orange (changed from green for contrast with Zr)
"Cu": (129, 0, 255), # Brown (changed from JMOL)
"Zn": (168, 74, 0), # Light blue
"Zr": (108, 96, 208), # Cyan (kept)
"Nb": (134, 228, 15), # Purple (new)
# Post-transition metals - earth tones
"Al": (102, 211, 188), # Gray (from JMOL)
"Ga": (255, 121, 143), # Rose
"In": (131, 143, 93), # Dusty rose
"Sn": (197, 163, 255), # Dark orange (changed from blue-gray for contrast with Zr)
"Tl": (0, 46, 133), # Terra cotta
"Pb": (0, 151, 134), # Dark gray
"Bi": (0, 255, 121), # Purple
# Noble metals - preserving traditional colors
"Ru": (99, 0, 62), # Teal
"Rh": (129, 0, 255), # Hot pink
"Pd": (168, 74, 0), # Blue (from JMOL)
"Ag": (108, 96, 208), # Silver (from JMOL)
"Os": (134, 228, 15), # Blue (from JMOL)
"Ir": (102, 211, 188), # Dark blue (from JMOL)
"Pt": (255, 121, 143), # Light gray (from JMOL)
"Au": (131, 143, 93), # Gold (from JMOL)
}

ELEM_COLORS_ALLOY: dict[str, RgbColorType] = {
elem: (r / 255, g / 255, b / 255)
for elem, (r, g, b) in ELEM_COLORS_ALLOY_256.items()
}
20 changes: 7 additions & 13 deletions pymatviz/coordination/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from plotly.subplots import make_subplots
from pymatgen.analysis.local_env import NearNeighbors

from pymatviz.colors import ELEM_COLORS_JMOL, ELEM_COLORS_VESTA
from pymatviz.colors import ELEM_COLORS_JMOL
from pymatviz.coordination.helpers import (
CnSplitMode,
calculate_average_cn,
Expand Down Expand Up @@ -145,14 +145,10 @@ def coordination_hist(
if isinstance(element_color_scheme, dict):
# Merge custom colors with default Jmol colors to get a complete color scheme
element_colors = ELEM_COLORS_JMOL | element_color_scheme
elif element_color_scheme == ElemColorScheme.jmol:
element_colors = ELEM_COLORS_JMOL
elif element_color_scheme == ElemColorScheme.vesta:
element_colors = ELEM_COLORS_VESTA
elif isinstance(element_color_scheme, dict):
element_colors = element_color_scheme
elif isinstance(element_color_scheme, ElemColorScheme):
element_colors = element_color_scheme.color_map
else:
raise ValueError(
raise TypeError(
f"Invalid {element_color_scheme=}. Must be {', '.join(ElemColorScheme)} "
f"or a custom dict."
)
Expand Down Expand Up @@ -403,12 +399,10 @@ def coordination_vs_cutoff_line(

if isinstance(element_color_scheme, dict):
element_colors = ELEM_COLORS_JMOL | element_color_scheme
elif element_color_scheme == ElemColorScheme.jmol:
element_colors = ELEM_COLORS_JMOL
elif element_color_scheme == ElemColorScheme.vesta:
element_colors = ELEM_COLORS_VESTA
elif isinstance(element_color_scheme, ElemColorScheme):
element_colors = element_color_scheme.color_map
else:
raise ValueError(
raise TypeError(
f"Invalid {element_color_scheme=}. Must be {', '.join(ElemColorScheme)} "
"or a custom dict."
)
Expand Down
21 changes: 19 additions & 2 deletions pymatviz/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from typing_extensions import Self

from pymatviz.typing import RgbColorType

# TODO: remove following definition of StrEnum once Python 3.11+
if sys.version_info >= (3, 11):
from enum import StrEnum
Expand Down Expand Up @@ -190,6 +192,8 @@ class Key(LabelEnum):
symmetry_decrease = "symmetry_decrease", "Symmetry Decrease"
symmetry_increase = "symmetry_increase", "Symmetry Increase"
symmetry_match = "symmetry_match", "Symmetry Match"
symprec = "symprec", "Symmetry Precision"
angle_tolerance = "angle_tolerance", "Angle Tolerance"
point_group = "point_group", "Point Group"
n_wyckoff = "n_wyckoff", "Number of Wyckoff Positions"
n_rot_syms = "n_rot_syms", "Number of rotational symmetries"
Expand Down Expand Up @@ -531,6 +535,10 @@ class Key(LabelEnum):
mse = "MSE", "Mean Squared Error"
rmse = "RMSE", "Root Mean Squared Error"
rmsd = "rmsd", "Root Mean Square Deviation"
n_sym_ops_mae = (
"n_sym_ops_mae",
"Mean Absolute Error in Number of Symmetry Operations",
)
structure_rmsd = "structure_rmsd", f"Structure RMSD {angstrom}"
mape = "MAPE", "Mean Absolute Percentage Error"
srme = "SRME", "Symmetric Relative Mean Error"
Expand Down Expand Up @@ -785,10 +793,19 @@ class ElemColorScheme(LabelEnum):
"""

# key, label, color
# from https://wikipedia.org/wiki/Jmol"
jmol = "jmol", "Jmol", "Java-based molecular visualization"
# https://wikipedia.org/wiki/Jmol"
# from https://jp-minerals.org/vesta
vesta = "vesta", "VESTA", "Visualization for Electronic Structural Analysis"
# https://jp-minerals.org/vesta
# custom made for pymatviz
alloy = "alloy", "Alloy", "High-contrast color scheme optimized for metal alloys"

@property
def color_map(self) -> dict[str, RgbColorType]:
"""Return map from element symbol to color."""
import pymatviz.colors as pmv_colors

return getattr(pmv_colors, f"ELEM_COLORS_{self.value.upper()}")


@unique
Expand Down
47 changes: 36 additions & 11 deletions pymatviz/ptable/ptable_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import TYPE_CHECKING
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, TypeAlias

import numpy as np
import pandas as pd
Expand All @@ -24,10 +24,14 @@


if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any, Literal


ColorScale: TypeAlias = (
str | Sequence[str] | Sequence[tuple[float, str]] | Callable[[str, float, int], str]
)


def ptable_heatmap_plotly(
values: ElemValues,
*,
Expand Down Expand Up @@ -683,10 +687,10 @@ def ptable_hists_plotly(
def ptable_heatmap_splits_plotly(
data: pd.DataFrame | pd.Series | dict[str, list[float]],
*,
# Split-specific
# Split
orientation: Literal["diagonal", "horizontal", "vertical", "grid"] = "diagonal",
# Figure
colorscale: str | Sequence[str] | Sequence[tuple[float, str]] = "viridis",
colorscale: ColorScale = "viridis",
colorbar: dict[str, Any] | Literal[False] | None = None,
on_empty: Literal["hide", "show"] = "hide",
hide_f_block: bool | Literal["auto"] = "auto",
Expand Down Expand Up @@ -714,8 +718,12 @@ def ptable_heatmap_splits_plotly(
1st value would be plotted in lower-left corner, 2nd in the upper-right.
--- Figure ---
colorscale (str | list[str] | list[tuple[float, str]]): Color scale for heatmap.
Defaults to "viridis".
colorscale (ColorScale): Color scale for heatmap. Defaults to "viridis". Can be:
- str: Name of built-in colorscale ("turbo", "inferno", "plasma", ...)
- list[str]: List of colors to interpolate between
- list[tuple[float, str]]: List of (position, color) pairs
- Callable[[str, float, int], str]: Function mapping (element symbol, split
value, split index) to color string. Useful for custom color schemes.
colorbar (dict[str, Any] | None): Plotly colorbar properties. Defaults to
dict(orientation="h"). See https://plotly.com/python/reference#heatmap-colorbar
for available options. Set to False to hide the colorbar.
Expand Down Expand Up @@ -750,8 +758,8 @@ def ptable_heatmap_splits_plotly(
--- Additional options ---
nan_color (str): Color for NaN values. Defaults to "#eff".
hover_data (dict[str, str | int | float] | pd.Series): Additional data for
hover tooltip.
hover_data (dict[str, str] | pd.Series): Map from element symbol to hover text.
to additional text to append to hover tooltip.
subplot_kwargs (dict | None): Additional keywords passed to make_subplots().
Returns:
Expand All @@ -762,6 +770,7 @@ def ptable_heatmap_splits_plotly(
ValueError: If n_splits not in {2, 3, 4} or orientation="grid" with n_splits!=4
"""
import plotly.colors
from pymatgen.core import Element

if isinstance(data, pd.Series): # Process input data
data = data.to_dict()
Expand All @@ -782,6 +791,15 @@ def ptable_heatmap_splits_plotly(
) | (subplot_kwargs or {})
fig = make_subplots(**subplot_kwargs)

# warn about unrecognized element symbols
unrecognized_element_symbols = set(data) - {*map(str, Element)}
if unrecognized_element_symbols:
warnings.warn(
f"{unrecognized_element_symbols=}\nShould be simple strings of element "
"symbols",
stacklevel=2,
)

def create_section_coords(
n_splits: Literal[2, 3, 4],
orientation: Literal["diagonal", "horizontal", "vertical", "grid"],
Expand Down Expand Up @@ -871,13 +889,18 @@ def create_section_coords(

# Create sections
sections = create_section_coords(len(values), orientation) # type: ignore[arg-type]
for idx, (xs, ys) in enumerate(sections):
for idx, (xs, ys) in enumerate(sections): # Loop over element tile splits
if len(values) <= idx or np.isnan(values[idx]):
color = nan_color
elif callable(colorscale):
# Use the callable to get color directly
color = colorscale(symbol, values[idx], idx)
else:
# Use plotly builtin color interpolation logic
color = plotly.colors.sample_colorscale(
colorscale, (values[idx] - cbar_min) / (cbar_max - cbar_min)
)[0]

fig.add_scatter(
x=xs,
y=ys,
Expand Down Expand Up @@ -972,7 +995,9 @@ def create_section_coords(
)

# Add colorbar
if colorbar is not False:
if colorbar is not False and not callable(colorscale):
# TODO don't skip colorbar if colorscale is callable. problem: can't sample and
# interpolate callable to get color strings since it could be discrete
colorbar = dict(orientation="h", lenmode="fraction", thickness=15) | (
colorbar or {}
)
Expand Down
Loading

0 comments on commit 15c2642

Please sign in to comment.