Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Equalize Phonon(Dos|BS)Plotter colors, allow custom plot settings per-DOS #3514

Merged
merged 6 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pymatgen/phonon/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _positive_densities(self) -> np.ndarray:
"""Numpy array containing the list of densities corresponding to positive frequencies."""
return self.densities[self.ind_zero_freq :]

def cv(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
def cv(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
"""Constant volume specific heat C_v at temperature T obtained from the integration of the DOS.
Only positive frequencies will be used.
Result in J/(K*mol-c). A mol-c is the abbreviation of a mole-cell, that is, the number
Expand Down Expand Up @@ -198,7 +198,7 @@ def csch2(x):

return cv

def entropy(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
def entropy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
"""Vibrational entropy at temperature T obtained from the integration of the DOS.
Only positive frequencies will be used.
Result in J/(K*mol-c). A mol-c is the abbreviation of a mole-cell, that is, the number
Expand Down Expand Up @@ -233,7 +233,7 @@ def entropy(self, temp: float, structure: Structure | None = None, **kwargs) ->

return entropy

def internal_energy(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
def internal_energy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
"""Phonon contribution to the internal energy at temperature T obtained from the integration of the DOS.
Only positive frequencies will be used.
Result in J/mol-c. A mol-c is the abbreviation of a mole-cell, that is, the number
Expand Down Expand Up @@ -268,7 +268,7 @@ def internal_energy(self, temp: float, structure: Structure | None = None, **kwa

return e_phonon

def helmholtz_free_energy(self, temp: float, structure: Structure | None = None, **kwargs) -> float:
def helmholtz_free_energy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float:
"""Phonon contribution to the Helmholtz free energy at temperature T obtained from the integration of the DOS.
Only positive frequencies will be used.
Result in J/mol-c. A mol-c is the abbreviation of a mole-cell, that is, the number
Expand Down
94 changes: 40 additions & 54 deletions pymatgen/phonon/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import matplotlib.pyplot as plt
import numpy as np
import palettable
import scipy.constants as const
from matplotlib.collections import LineCollection
from monty.json import jsanitize
Expand Down Expand Up @@ -95,19 +94,18 @@ def __init__(self, stack: bool = False, sigma: float | None = None) -> None:
)
self.stack = stack
self.sigma = sigma
self._doses: dict[str, dict[Literal["frequencies", "densities"], np.ndarray]] = {}
self._doses: dict[str, dict[str, np.ndarray]] = {}

def add_dos(self, label: str, dos: PhononDos) -> None:
def add_dos(self, label: str, dos: PhononDos, **kwargs: Any) -> None:
"""Adds a dos for plotting.

Args:
label:
label for the DOS. Must be unique.
dos:
PhononDos object
label (str): label for the DOS. Must be unique.
dos (PhononDos): DOS object
**kwargs: kwargs supported by matplotlib.pyplot.plot
"""
densities = dos.get_smeared_densities(self.sigma) if self.sigma else dos.densities
self._doses[label] = {"frequencies": dos.frequencies, "densities": densities}
self._doses[label] = {"frequencies": dos.frequencies, "densities": densities, **kwargs}

def add_dos_dict(self, dos_dict: dict, key_sort_func=None) -> None:
"""Add a dictionary of doses, with an optional sorting function for the
Expand Down Expand Up @@ -160,8 +158,6 @@ def get_plot(
n_colors = max(3, len(self._doses))
n_colors = min(9, n_colors)

colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors

y = None
all_densities = []
all_frequencies = []
Expand All @@ -186,18 +182,14 @@ def get_plot(
all_densities.reverse()
all_frequencies.reverse()
all_pts = []
colors = ("blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive")
for idx, (key, frequencies, densities) in enumerate(zip(keys, all_frequencies, all_densities)):
color = self._doses[key].get("color", colors[idx % n_colors])
all_pts.extend(list(zip(frequencies, densities)))
if self.stack:
ax.fill(frequencies, densities, color=colors[idx % n_colors], label=str(key))
ax.fill(frequencies, densities, color=color, label=str(key))
else:
ax.plot(
frequencies,
densities,
color=colors[idx % n_colors],
label=str(key),
linewidth=3,
)
ax.plot(frequencies, densities, color=color, label=str(key), linewidth=3)

if xlim:
ax.set_xlim(xlim)
Expand Down Expand Up @@ -297,13 +289,9 @@ def _make_ticks(self, ax: Axes) -> Axes:
ax.set_xticks(uniq_d)
ax.set_xticklabels(uniq_l)

for idx in range(len(ticks["label"])):
if ticks["label"][idx] is not None:
# don't print the same label twice
if idx != 0:
ax.axvline(ticks["distance"][idx], color="k")
else:
ax.axvline(ticks["distance"][idx], color="k")
for idx, label in enumerate(ticks["label"]):
if label is not None:
ax.axvline(ticks["distance"][idx], color="k")
return ax

def bs_plot_data(self) -> dict[str, Any]:
Expand Down Expand Up @@ -356,14 +344,11 @@ def get_plot(
ax = pretty_plot(12, 8)

data = self.bs_plot_data()
for d in range(len(data["distances"])):
kwargs.setdefault("color", "blue")
for dists, freqs in zip(data["distances"], data["frequency"]):
for idx in range(self._nb_bands):
ax.plot(
data["distances"][d],
[data["frequency"][d][idx][j] * u.factor for j in range(len(data["distances"][d]))],
"b-",
**kwargs,
)
ys = [freqs[idx][j] * u.factor for j in range(len(dists))]
ax.plot(dists, ys, **kwargs)

self._make_ticks(ax)

Expand Down Expand Up @@ -598,15 +583,15 @@ def get_ticks(self) -> dict[str, list]:
label0 = f"${label0}$"
tick_labels.pop()
tick_distance.pop()
tick_labels.append(f"{label0}$\\mid${label1}")
tick_labels.append(f"{label0}|{label1}")
elif point.label.startswith("\\") or point.label.find("_") != -1:
tick_labels.append(f"${point.label}$")
else:
# map atomate2 all-upper-case point.labels to pretty LaTeX
label = dict(GAMMA=r"$\Gamma$", DELTA=r"$\Delta$").get(point.label, point.label)
tick_labels.append(label)
tick_labels.append(point.label)
previous_label = point.label
previous_branch = this_branch
# map atomate2 all-upper-case labels like GAMMA/DELTA to pretty symbols
tick_labels = [label.replace("GAMMA", "Γ").replace("DELTA", "Δ").replace("SIGMA", "Σ") for label in tick_labels]
return {"distance": tick_distance, "label": tick_labels}

def plot_compare(
Expand All @@ -616,6 +601,7 @@ def plot_compare(
labels: tuple[str, str] | None = None,
legend_kwargs: dict | None = None,
on_incompatible: Literal["raise", "warn", "ignore"] = "raise",
other_kwargs: dict | None = None,
**kwargs,
) -> Axes:
"""Plot two band structure for comparison. One is in red the other in blue.
Expand All @@ -634,14 +620,16 @@ def plot_compare(
legend_kwargs: dict[str, Any]: kwargs passed to ax.legend().
on_incompatible ('raise' | 'warn' | 'ignore'): What to do if the two band structures are not compatible.
Defaults to 'raise'.
other_kwargs: dict[str, Any]: kwargs passed to other_plotter ax.plot().
**kwargs: passed to ax.plot().

Returns:
a matplotlib object with both band structures
"""
unit = freq_units(units)
legend_kwargs = legend_kwargs or {}
legend_kwargs.setdefault("fontsize", 22)
other_kwargs = other_kwargs or {}
legend_kwargs.setdefault("fontsize", 20)

data_orig = self.bs_plot_data()
data = other_plotter.bs_plot_data()
Expand All @@ -656,24 +644,22 @@ def plot_compare(
line_width = kwargs.setdefault("linewidth", 1)

ax = self.get_plot(units=units, **kwargs)
for band_idx in range(other_plotter._nb_bands):
for dist_idx in range(len(data_orig["distances"])):
ax.plot(
data_orig["distances"][dist_idx],
[
data["frequency"][dist_idx][band_idx][j] * unit.factor
for j in range(len(data_orig["distances"][dist_idx]))
],
"r-",
**kwargs,
)

# add legend showing which color correspond to which band structure
if labels is None and self._label and other_plotter._label:
labels = (self._label, other_plotter._label)
if labels:
ax.plot([], [], "b-", label=labels[0], linewidth=3 * line_width)
ax.plot([], [], "r-", label=labels[1], linewidth=3 * line_width)
kwargs.setdefault("color", "red") # don't move this line up! it would mess up self.get_plot color

for band_idx in range(other_plotter._nb_bands):
for dist_idx, dists in enumerate(data_orig["distances"]):
xs = dists
ys = [data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))]
ax.plot(xs, ys, **(kwargs | other_kwargs))

# add legend showing which color corresponds to which band structure
if labels or (self._label and other_plotter._label):
color_self, color_other = ax.lines[0].get_color(), ax.lines[-1].get_color()
label_self, label_other = labels or (self._label, other_plotter._label)
ax.plot([], [], label=label_self, linewidth=2 * line_width, color=color_self)
linestyle = other_kwargs.get("linestyle", "-")
ax.plot([], [], label=label_other, linewidth=2 * line_width, color=color_other, linestyle=linestyle)
ax.legend(**legend_kwargs)

return ax
Expand Down