From 669568ba48b9df5075ab8b7d9b321e9be8a4fe40 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 13 Dec 2023 15:05:30 -0800 Subject: [PATCH 1/6] make temp optional to allow falling back to t if temp not passed --- pymatgen/phonon/dos.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymatgen/phonon/dos.py b/pymatgen/phonon/dos.py index dcdfb13d645..c08e7f8aca7 100644 --- a/pymatgen/phonon/dos.py +++ b/pymatgen/phonon/dos.py @@ -161,7 +161,7 @@ def _positive_densities(self) -> np.ndarray: """Numpy array containing the list of densities corresponding to positive frequencies.""" return self.densities[self.ind_zero_freq :] - def cv(self, temp: float, structure: Structure | None = None, **kwargs) -> float: + def cv(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float: """Constant volume specific heat C_v at temperature T obtained from the integration of the DOS. Only positive frequencies will be used. Result in J/(K*mol-c). A mol-c is the abbreviation of a mole-cell, that is, the number @@ -198,7 +198,7 @@ def csch2(x): return cv - def entropy(self, temp: float, structure: Structure | None = None, **kwargs) -> float: + def entropy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float: """Vibrational entropy at temperature T obtained from the integration of the DOS. Only positive frequencies will be used. Result in J/(K*mol-c). A mol-c is the abbreviation of a mole-cell, that is, the number @@ -233,7 +233,7 @@ def entropy(self, temp: float, structure: Structure | None = None, **kwargs) -> return entropy - def internal_energy(self, temp: float, structure: Structure | None = None, **kwargs) -> float: + def internal_energy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float: """Phonon contribution to the internal energy at temperature T obtained from the integration of the DOS. Only positive frequencies will be used. Result in J/mol-c. A mol-c is the abbreviation of a mole-cell, that is, the number @@ -268,7 +268,7 @@ def internal_energy(self, temp: float, structure: Structure | None = None, **kwa return e_phonon - def helmholtz_free_energy(self, temp: float, structure: Structure | None = None, **kwargs) -> float: + def helmholtz_free_energy(self, temp: float | None = None, structure: Structure | None = None, **kwargs) -> float: """Phonon contribution to the Helmholtz free energy at temperature T obtained from the integration of the DOS. Only positive frequencies will be used. Result in J/mol-c. A mol-c is the abbreviation of a mole-cell, that is, the number From cb7f2bc4d775acdc50a6a682f175cc810cb4ccf9 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 13 Dec 2023 15:06:38 -0800 Subject: [PATCH 2/6] allow passing arbitrary kwargs into PhononDosPlotter.add_dos for use in e.g. color customization --- pymatgen/phonon/plotter.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index 784ee8f866a..e8da61a8db9 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -95,19 +95,18 @@ def __init__(self, stack: bool = False, sigma: float | None = None) -> None: ) self.stack = stack self.sigma = sigma - self._doses: dict[str, dict[Literal["frequencies", "densities"], np.ndarray]] = {} + self._doses: dict[str, dict[str, np.ndarray]] = {} - def add_dos(self, label: str, dos: PhononDos) -> None: + def add_dos(self, label: str, dos: PhononDos, **kwargs: Any) -> None: """Adds a dos for plotting. Args: - label: - label for the DOS. Must be unique. - dos: - PhononDos object + label (str): label for the DOS. Must be unique. + dos (PhononDos): DOS object + **kwargs: kwargs supported by matplotlib.pyplot.plot """ densities = dos.get_smeared_densities(self.sigma) if self.sigma else dos.densities - self._doses[label] = {"frequencies": dos.frequencies, "densities": densities} + self._doses[label] = {"frequencies": dos.frequencies, "densities": densities, **kwargs} def add_dos_dict(self, dos_dict: dict, key_sort_func=None) -> None: """Add a dictionary of doses, with an optional sorting function for the From 64cd6ced618bd77893334ca8aa07b69ce6b832f9 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 13 Dec 2023 15:08:04 -0800 Subject: [PATCH 3/6] change default line colors of PhononDosPlotter and PhononBSPlotter to tab:10 tab:blue and tab:orange in particular --- pymatgen/phonon/plotter.py | 66 ++++++++++++++------------------------ 1 file changed, 24 insertions(+), 42 deletions(-) diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index e8da61a8db9..eeaf2c1bcf8 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -8,7 +8,6 @@ import matplotlib.pyplot as plt import numpy as np -import palettable import scipy.constants as const from matplotlib.collections import LineCollection from monty.json import jsanitize @@ -159,8 +158,6 @@ def get_plot( n_colors = max(3, len(self._doses)) n_colors = min(9, n_colors) - colors = palettable.colorbrewer.qualitative.Set1_9.mpl_colors - y = None all_densities = [] all_frequencies = [] @@ -186,17 +183,12 @@ def get_plot( all_frequencies.reverse() all_pts = [] for idx, (key, frequencies, densities) in enumerate(zip(keys, all_frequencies, all_densities)): + color = self._doses[key].get("color", plt.cm.tab10.colors[idx % n_colors]) all_pts.extend(list(zip(frequencies, densities))) if self.stack: - ax.fill(frequencies, densities, color=colors[idx % n_colors], label=str(key)) + ax.fill(frequencies, densities, color=color, label=str(key)) else: - ax.plot( - frequencies, - densities, - color=colors[idx % n_colors], - label=str(key), - linewidth=3, - ) + ax.plot(frequencies, densities, color=color, label=str(key), linewidth=3) if xlim: ax.set_xlim(xlim) @@ -296,13 +288,9 @@ def _make_ticks(self, ax: Axes) -> Axes: ax.set_xticks(uniq_d) ax.set_xticklabels(uniq_l) - for idx in range(len(ticks["label"])): - if ticks["label"][idx] is not None: - # don't print the same label twice - if idx != 0: - ax.axvline(ticks["distance"][idx], color="k") - else: - ax.axvline(ticks["distance"][idx], color="k") + for idx, label in enumerate(ticks["label"]): + if label is not None: + ax.axvline(ticks["distance"][idx], color="k") return ax def bs_plot_data(self) -> dict[str, Any]: @@ -355,14 +343,11 @@ def get_plot( ax = pretty_plot(12, 8) data = self.bs_plot_data() - for d in range(len(data["distances"])): + kwargs.setdefault("color", "tab:blue") + for dists, freqs in zip(data["distances"], data["frequency"]): for idx in range(self._nb_bands): - ax.plot( - data["distances"][d], - [data["frequency"][d][idx][j] * u.factor for j in range(len(data["distances"][d]))], - "b-", - **kwargs, - ) + ys = [freqs[idx][j] * u.factor for j in range(len(dists))] + ax.plot(dists, ys, **kwargs) self._make_ticks(ax) @@ -655,24 +640,21 @@ def plot_compare( line_width = kwargs.setdefault("linewidth", 1) ax = self.get_plot(units=units, **kwargs) - for band_idx in range(other_plotter._nb_bands): - for dist_idx in range(len(data_orig["distances"])): - ax.plot( - data_orig["distances"][dist_idx], - [ - data["frequency"][dist_idx][band_idx][j] * unit.factor - for j in range(len(data_orig["distances"][dist_idx])) - ], - "r-", - **kwargs, - ) - # add legend showing which color correspond to which band structure - if labels is None and self._label and other_plotter._label: - labels = (self._label, other_plotter._label) - if labels: - ax.plot([], [], "b-", label=labels[0], linewidth=3 * line_width) - ax.plot([], [], "r-", label=labels[1], linewidth=3 * line_width) + kwargs.setdefault("color", "tab:orange") # don't move this line up! it would mess up self.get_plot color + + for band_idx in range(other_plotter._nb_bands): + for dist_idx, dists in enumerate(data_orig["distances"]): + xs = dists + ys = [data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))] + ax.plot(xs, ys, **kwargs) + + # add legend showing which color corresponds to which band structure + if labels or (self._label and other_plotter._label): + color_self, color_other = ax.lines[0].get_color(), ax.lines[-1].get_color() + label_self, label_other = labels or (self._label, other_plotter._label) + ax.plot([], [], label=label_self, linewidth=3 * line_width, color=color_self) + ax.plot([], [], label=label_other, linewidth=3 * line_width, color=color_other) ax.legend(**legend_kwargs) return ax From 6deff781bd4dd15de7a7fd5f47c2772ce8ae4275 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 13 Dec 2023 15:09:05 -0800 Subject: [PATCH 4/6] fix overlapping an non-symbol band struct x-labels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit label.replace("GAMMA", "Γ").replace("DELTA", "Δ") --- pymatgen/phonon/plotter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index eeaf2c1bcf8..12f4a029dd7 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -582,15 +582,15 @@ def get_ticks(self) -> dict[str, list]: label0 = f"${label0}$" tick_labels.pop() tick_distance.pop() - tick_labels.append(f"{label0}$\\mid${label1}") + tick_labels.append(f"{label0}|{label1}") elif point.label.startswith("\\") or point.label.find("_") != -1: tick_labels.append(f"${point.label}$") else: - # map atomate2 all-upper-case point.labels to pretty LaTeX - label = dict(GAMMA=r"$\Gamma$", DELTA=r"$\Delta$").get(point.label, point.label) - tick_labels.append(label) + tick_labels.append(point.label) previous_label = point.label previous_branch = this_branch + # map atomate2 all-upper-case labels like GAMMA/DELTA to pretty symbols + tick_labels = [label.replace("GAMMA", "Γ").replace("DELTA", "Δ") for label in tick_labels] return {"distance": tick_distance, "label": tick_labels} def plot_compare( From 40e0b9230d77de8966b24892410db13961ee7d1e Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 13 Dec 2023 16:42:09 -0800 Subject: [PATCH 5/6] change colors from tab10 back to regular red/blue --- pymatgen/phonon/plotter.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index 12f4a029dd7..9a676d9c069 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -182,8 +182,9 @@ def get_plot( all_densities.reverse() all_frequencies.reverse() all_pts = [] + colors = ("blue", "red", "green", "orange", "purple", "brown", "pink", "gray", "olive") for idx, (key, frequencies, densities) in enumerate(zip(keys, all_frequencies, all_densities)): - color = self._doses[key].get("color", plt.cm.tab10.colors[idx % n_colors]) + color = self._doses[key].get("color", colors[idx % n_colors]) all_pts.extend(list(zip(frequencies, densities))) if self.stack: ax.fill(frequencies, densities, color=color, label=str(key)) @@ -343,7 +344,7 @@ def get_plot( ax = pretty_plot(12, 8) data = self.bs_plot_data() - kwargs.setdefault("color", "tab:blue") + kwargs.setdefault("color", "blue") for dists, freqs in zip(data["distances"], data["frequency"]): for idx in range(self._nb_bands): ys = [freqs[idx][j] * u.factor for j in range(len(dists))] @@ -641,7 +642,7 @@ def plot_compare( ax = self.get_plot(units=units, **kwargs) - kwargs.setdefault("color", "tab:orange") # don't move this line up! it would mess up self.get_plot color + kwargs.setdefault("color", "red") # don't move this line up! it would mess up self.get_plot color for band_idx in range(other_plotter._nb_bands): for dist_idx, dists in enumerate(data_orig["distances"]): From 31c0a9bf6180a79eb5e52de67082ac0f801d5176 Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 13 Dec 2023 16:42:51 -0800 Subject: [PATCH 6/6] plot_compare add keyword other_kwargs to customize 2nd set of band lines --- pymatgen/phonon/plotter.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pymatgen/phonon/plotter.py b/pymatgen/phonon/plotter.py index 9a676d9c069..a89972fd24c 100644 --- a/pymatgen/phonon/plotter.py +++ b/pymatgen/phonon/plotter.py @@ -591,7 +591,7 @@ def get_ticks(self) -> dict[str, list]: previous_label = point.label previous_branch = this_branch # map atomate2 all-upper-case labels like GAMMA/DELTA to pretty symbols - tick_labels = [label.replace("GAMMA", "Γ").replace("DELTA", "Δ") for label in tick_labels] + tick_labels = [label.replace("GAMMA", "Γ").replace("DELTA", "Δ").replace("SIGMA", "Σ") for label in tick_labels] return {"distance": tick_distance, "label": tick_labels} def plot_compare( @@ -601,6 +601,7 @@ def plot_compare( labels: tuple[str, str] | None = None, legend_kwargs: dict | None = None, on_incompatible: Literal["raise", "warn", "ignore"] = "raise", + other_kwargs: dict | None = None, **kwargs, ) -> Axes: """Plot two band structure for comparison. One is in red the other in blue. @@ -619,6 +620,7 @@ def plot_compare( legend_kwargs: dict[str, Any]: kwargs passed to ax.legend(). on_incompatible ('raise' | 'warn' | 'ignore'): What to do if the two band structures are not compatible. Defaults to 'raise'. + other_kwargs: dict[str, Any]: kwargs passed to other_plotter ax.plot(). **kwargs: passed to ax.plot(). Returns: @@ -626,7 +628,8 @@ def plot_compare( """ unit = freq_units(units) legend_kwargs = legend_kwargs or {} - legend_kwargs.setdefault("fontsize", 22) + other_kwargs = other_kwargs or {} + legend_kwargs.setdefault("fontsize", 20) data_orig = self.bs_plot_data() data = other_plotter.bs_plot_data() @@ -648,14 +651,15 @@ def plot_compare( for dist_idx, dists in enumerate(data_orig["distances"]): xs = dists ys = [data["frequency"][dist_idx][band_idx][j] * unit.factor for j in range(len(dists))] - ax.plot(xs, ys, **kwargs) + ax.plot(xs, ys, **(kwargs | other_kwargs)) # add legend showing which color corresponds to which band structure if labels or (self._label and other_plotter._label): color_self, color_other = ax.lines[0].get_color(), ax.lines[-1].get_color() label_self, label_other = labels or (self._label, other_plotter._label) - ax.plot([], [], label=label_self, linewidth=3 * line_width, color=color_self) - ax.plot([], [], label=label_other, linewidth=3 * line_width, color=color_other) + ax.plot([], [], label=label_self, linewidth=2 * line_width, color=color_self) + linestyle = other_kwargs.get("linestyle", "-") + ax.plot([], [], label=label_other, linewidth=2 * line_width, color=color_other, linestyle=linestyle) ax.legend(**legend_kwargs) return ax